diff --git a/haystack_experimental/components/extractors/llm_metadata_extractor.py b/haystack_experimental/components/extractors/llm_metadata_extractor.py index 84896a6b..00786b11 100644 --- a/haystack_experimental/components/extractors/llm_metadata_extractor.py +++ b/haystack_experimental/components/extractors/llm_metadata_extractor.py @@ -75,7 +75,7 @@ class LLMMetadataExtractor: ```python from haystack import Document - from dc_custom_component.extractors.llm_metadata_extractor import TEGLLMMetadataExtractor + from haystack_experimental.components.extractors.llm_metadata_extractor import LLMMetadataExtractor NER_PROMPT = ''' -Goal- @@ -187,7 +187,7 @@ def __init__( # pylint: disable=R0917 ast = SandboxedEnvironment().parse(prompt) template_variables = meta.find_undeclared_variables(ast) variables = list(template_variables) - if len(variables) != 1 and variables[0] != "document": + if len(variables) > 1 or variables[0] != "document": raise ValueError( f"Prompt must have exactly one variable called 'document'. Found {','.join(variables)} in the prompt." ) diff --git a/test/components/extractors/test_llm_metadata_extractor.py b/test/components/extractors/test_llm_metadata_extractor.py index 094633fb..fb471f54 100644 --- a/test/components/extractors/test_llm_metadata_extractor.py +++ b/test/components/extractors/test_llm_metadata_extractor.py @@ -14,21 +14,19 @@ class TestLLMMetadataExtractor: def test_init_default(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") extractor = LLMMetadataExtractor( - prompt="prompt {{test}}", + prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI, - prompt_variable="test" ) assert isinstance(extractor.builder, PromptBuilder) assert extractor.generator_api == LLMProvider.OPENAI assert extractor.expected_keys == ["key1", "key2"] assert extractor.raise_on_failure is False - assert extractor.prompt_variable == "test" def test_init_with_parameters(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") extractor = LLMMetadataExtractor( - prompt="prompt {{test}}", + prompt="prompt {{document.content}}", expected_keys=["key1", "key2"], raise_on_failure=True, generator_api=LLMProvider.OPENAI, @@ -36,7 +34,6 @@ def test_init_with_parameters(self, monkeypatch): 'model': 'gpt-3.5-turbo', 'generation_kwargs': {"temperature": 0.5} }, - prompt_variable="test", page_range=['1-5'] ) assert isinstance(extractor.builder, PromptBuilder) @@ -53,31 +50,28 @@ def test_init_missing_prompt_variable(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") with pytest.raises(ValueError): _ = LLMMetadataExtractor( - prompt="prompt {{test}}", + prompt="prompt {{ wrong_variable }}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI, - prompt_variable="test2" ) def test_to_dict_default_params(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") extractor = LLMMetadataExtractor( - prompt="some prompt that was used with the LLM {{test}}", + prompt="some prompt that was used with the LLM {{document.content}}", expected_keys=["key1", "key2"], generator_api=LLMProvider.OPENAI, - prompt_variable="test", generator_api_params={'model': 'gpt-4o-mini', 'generation_kwargs': {"temperature": 0.5}}, - raise_on_failure=True) - + raise_on_failure=True + ) extractor_dict = extractor.to_dict() assert extractor_dict == { 'type': 'haystack_experimental.components.extractors.llm_metadata_extractor.LLMMetadataExtractor', 'init_parameters': { - 'prompt': 'some prompt that was used with the LLM {{test}}', + 'prompt': 'some prompt that was used with the LLM {{document.content}}', 'expected_keys': ['key1', 'key2'], 'raise_on_failure': True, - 'prompt_variable': 'test', 'generator_api': 'openai', 'page_range': None, 'generator_api_params': { @@ -89,6 +83,7 @@ def test_to_dict_default_params(self, monkeypatch): 'streaming_callback': None, 'system_prompt': None, }, + 'max_workers': 3 } } @@ -97,10 +92,9 @@ def test_from_dict(self, monkeypatch): extractor_dict = { 'type': 'haystack_experimental.components.extractors.llm_metadata_extractor.LLMMetadataExtractor', 'init_parameters': { - 'prompt': 'some prompt that was used with the LLM {{test}}', + 'prompt': 'some prompt that was used with the LLM {{document.content}}', 'expected_keys': ['key1', 'key2'], 'raise_on_failure': True, - 'prompt_variable': 'test', 'generator_api': 'openai', 'generator_api_params': { 'api_base_url': None, @@ -116,32 +110,18 @@ def test_from_dict(self, monkeypatch): extractor = LLMMetadataExtractor.from_dict(extractor_dict) assert extractor.raise_on_failure is True assert extractor.expected_keys == ["key1", "key2"] - assert extractor.prompt == "some prompt that was used with the LLM {{test}}" + assert extractor.prompt == "some prompt that was used with the LLM {{document.content}}" assert extractor.generator_api == LLMProvider.OPENAI def test_output_invalid_json_raise_on_failure_true(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") extractor = LLMMetadataExtractor( - prompt="prompt {{test}}", - expected_keys=["key1", "key2"], + prompt="prompt {{document.content}}", generator_api=LLMProvider.OPENAI, - prompt_variable="test", raise_on_failure=True ) with pytest.raises(ValueError): - extractor.is_valid_json_and_has_expected_keys(expected=["entities"], received="""{"json": "output"}""") - - def test_output_valid_json_not_expected_keys(self, monkeypatch): - monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") - extractor = LLMMetadataExtractor( - prompt="prompt {{test}}", - expected_keys=["key1", "key2"], - generator_api=LLMProvider.OPENAI, - prompt_variable="test", - raise_on_failure=True - ) - with pytest.raises(ValueError): - extractor.is_valid_json_and_has_expected_keys(expected=["entities"], received="{'json': 'output'}") + extractor._extract_metadata(llm_answer="""{"json: "output"}""") @pytest.mark.integration @pytest.mark.skipif( @@ -154,40 +134,43 @@ def test_live_run(self): Document(content="Hugging Face is a company founded in Paris, France and is known for its Transformers library") ] - ner_prompt = """ - Given a text and a list of entity types, identify all entities of those types from the text. - - -Steps- - 1. Identify all entities. For each identified entity, extract the following information: - - entity_name: Name of the entity, capitalized - - entity_type: One of the following types: [organization, person, product, service, industry] - Format each entity as {"entity": , "entity_type": } - - 2. Return output in a single list with all the entities identified in steps 1. - - -Examples- - ###################### - Example 1: - entity_types: [organization, product, service, industry, investment strategy, market trend] - text: - Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top 10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer base and high cross-border usage. - We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent agreement with Emirates Skywards. - And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital issuers are equally - ------------------------ - output: - {"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]} - ############################# - - -Real Data- - ###################### - entity_types: [company, organization, person, country, product, service] - text: {{input_text}} - ###################### - output: - """ + ner_prompt = """Given a text and a list of entity types, identify all entities of those types from the text. + +-Steps- +1. Identify all entities. For each identified entity, extract the following information: +- entity_name: Name of the entity, capitalized +- entity_type: One of the following types: [organization, person, product, service, industry] +Format each entity as {"entity": , "entity_type": } + +2. Return output in a single list with all the entities identified in steps 1. + +-Examples- +###################### +Example 1: +entity_types: [organization, product, service, industry, investment strategy, market trend] +text: +Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top 10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer base and high cross-border usage. +We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent agreement with Emirates Skywards. +And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital issuers are equally +------------------------ +output: +{"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]} +############################# + +-Real Data- +###################### +entity_types: [company, organization, person, country, product, service] +text: {{input_text}} +###################### +output: +""" doc_store = InMemoryDocumentStore() - extractor = LLMMetadataExtractor(prompt=ner_prompt, expected_keys=["entities"], prompt_variable="input_text", generator_api=LLMProvider.OPENAI) + extractor = LLMMetadataExtractor( + prompt=ner_prompt, + expected_keys=["entities"], + generator_api=LLMProvider.OPENAI + ) writer = DocumentWriter(document_store=doc_store) pipeline = Pipeline() pipeline.add_component("extractor", extractor)