Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sjrl committed Dec 9, 2024
1 parent bfad2bf commit a5bf68d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class LLMMetadataExtractor:
```python
from haystack import Document
from dc_custom_component.extractors.llm_metadata_extractor import TEGLLMMetadataExtractor
from haystack_experimental.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
NER_PROMPT = '''
-Goal-
Expand Down Expand Up @@ -187,7 +187,7 @@ def __init__( # pylint: disable=R0917
ast = SandboxedEnvironment().parse(prompt)
template_variables = meta.find_undeclared_variables(ast)
variables = list(template_variables)
if len(variables) != 1 and variables[0] != "document":
if len(variables) > 1 or variables[0] != "document":
raise ValueError(
f"Prompt must have exactly one variable called 'document'. Found {','.join(variables)} in the prompt."
)
Expand Down
111 changes: 47 additions & 64 deletions test/components/extractors/test_llm_metadata_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,26 @@ class TestLLMMetadataExtractor:
def test_init_default(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="prompt {{test}}",
prompt="prompt {{document.content}}",
expected_keys=["key1", "key2"],
generator_api=LLMProvider.OPENAI,
prompt_variable="test"
)
assert isinstance(extractor.builder, PromptBuilder)
assert extractor.generator_api == LLMProvider.OPENAI
assert extractor.expected_keys == ["key1", "key2"]
assert extractor.raise_on_failure is False
assert extractor.prompt_variable == "test"

def test_init_with_parameters(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="prompt {{test}}",
prompt="prompt {{document.content}}",
expected_keys=["key1", "key2"],
raise_on_failure=True,
generator_api=LLMProvider.OPENAI,
generator_api_params={
'model': 'gpt-3.5-turbo',
'generation_kwargs': {"temperature": 0.5}
},
prompt_variable="test",
page_range=['1-5']
)
assert isinstance(extractor.builder, PromptBuilder)
Expand All @@ -53,31 +50,28 @@ def test_init_missing_prompt_variable(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
with pytest.raises(ValueError):
_ = LLMMetadataExtractor(
prompt="prompt {{test}}",
prompt="prompt {{ wrong_variable }}",
expected_keys=["key1", "key2"],
generator_api=LLMProvider.OPENAI,
prompt_variable="test2"
)

def test_to_dict_default_params(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="some prompt that was used with the LLM {{test}}",
prompt="some prompt that was used with the LLM {{document.content}}",
expected_keys=["key1", "key2"],
generator_api=LLMProvider.OPENAI,
prompt_variable="test",
generator_api_params={'model': 'gpt-4o-mini', 'generation_kwargs': {"temperature": 0.5}},
raise_on_failure=True)

raise_on_failure=True
)
extractor_dict = extractor.to_dict()

assert extractor_dict == {
'type': 'haystack_experimental.components.extractors.llm_metadata_extractor.LLMMetadataExtractor',
'init_parameters': {
'prompt': 'some prompt that was used with the LLM {{test}}',
'prompt': 'some prompt that was used with the LLM {{document.content}}',
'expected_keys': ['key1', 'key2'],
'raise_on_failure': True,
'prompt_variable': 'test',
'generator_api': 'openai',
'page_range': None,
'generator_api_params': {
Expand All @@ -89,6 +83,7 @@ def test_to_dict_default_params(self, monkeypatch):
'streaming_callback': None,
'system_prompt': None,
},
'max_workers': 3
}
}

Expand All @@ -97,10 +92,9 @@ def test_from_dict(self, monkeypatch):
extractor_dict = {
'type': 'haystack_experimental.components.extractors.llm_metadata_extractor.LLMMetadataExtractor',
'init_parameters': {
'prompt': 'some prompt that was used with the LLM {{test}}',
'prompt': 'some prompt that was used with the LLM {{document.content}}',
'expected_keys': ['key1', 'key2'],
'raise_on_failure': True,
'prompt_variable': 'test',
'generator_api': 'openai',
'generator_api_params': {
'api_base_url': None,
Expand All @@ -116,32 +110,18 @@ def test_from_dict(self, monkeypatch):
extractor = LLMMetadataExtractor.from_dict(extractor_dict)
assert extractor.raise_on_failure is True
assert extractor.expected_keys == ["key1", "key2"]
assert extractor.prompt == "some prompt that was used with the LLM {{test}}"
assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}"
assert extractor.generator_api == LLMProvider.OPENAI

def test_output_invalid_json_raise_on_failure_true(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="prompt {{test}}",
expected_keys=["key1", "key2"],
prompt="prompt {{document.content}}",
generator_api=LLMProvider.OPENAI,
prompt_variable="test",
raise_on_failure=True
)
with pytest.raises(ValueError):
extractor.is_valid_json_and_has_expected_keys(expected=["entities"], received="""{"json": "output"}""")

def test_output_valid_json_not_expected_keys(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
extractor = LLMMetadataExtractor(
prompt="prompt {{test}}",
expected_keys=["key1", "key2"],
generator_api=LLMProvider.OPENAI,
prompt_variable="test",
raise_on_failure=True
)
with pytest.raises(ValueError):
extractor.is_valid_json_and_has_expected_keys(expected=["entities"], received="{'json': 'output'}")
extractor._extract_metadata(llm_answer="""{"json: "output"}""")

@pytest.mark.integration
@pytest.mark.skipif(
Expand All @@ -154,40 +134,43 @@ def test_live_run(self):
Document(content="Hugging Face is a company founded in Paris, France and is known for its Transformers library")
]

ner_prompt = """
Given a text and a list of entity types, identify all entities of those types from the text.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [organization, person, product, service, industry]
Format each entity as {"entity": <entity_name>, "entity_type": <entity_type>}
2. Return output in a single list with all the entities identified in steps 1.
-Examples-
######################
Example 1:
entity_types: [organization, product, service, industry, investment strategy, market trend]
text:
Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top 10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer base and high cross-border usage.
We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent agreement with Emirates Skywards.
And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital issuers are equally
------------------------
output:
{"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]}
#############################
-Real Data-
######################
entity_types: [company, organization, person, country, product, service]
text: {{input_text}}
######################
output:
"""
ner_prompt = """Given a text and a list of entity types, identify all entities of those types from the text.
-Steps-
1. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [organization, person, product, service, industry]
Format each entity as {"entity": <entity_name>, "entity_type": <entity_type>}
2. Return output in a single list with all the entities identified in steps 1.
-Examples-
######################
Example 1:
entity_types: [organization, product, service, industry, investment strategy, market trend]
text:
Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top 10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer base and high cross-border usage.
We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent agreement with Emirates Skywards.
And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital issuers are equally
------------------------
output:
{"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]}
#############################
-Real Data-
######################
entity_types: [company, organization, person, country, product, service]
text: {{input_text}}
######################
output:
"""

doc_store = InMemoryDocumentStore()
extractor = LLMMetadataExtractor(prompt=ner_prompt, expected_keys=["entities"], prompt_variable="input_text", generator_api=LLMProvider.OPENAI)
extractor = LLMMetadataExtractor(
prompt=ner_prompt,
expected_keys=["entities"],
generator_api=LLMProvider.OPENAI
)
writer = DocumentWriter(document_store=doc_store)
pipeline = Pipeline()
pipeline.add_component("extractor", extractor)
Expand Down

0 comments on commit a5bf68d

Please sign in to comment.