From 108ae8e6ee4d32faf75f7b0cf12fa7a92284a48f Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 24 Apr 2023 10:13:10 -0400 Subject: [PATCH] update --- kor/extraction/api.py | 10 ++++++++- .../extraction/test_extraction_with_chain.py | 21 ++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/kor/extraction/api.py b/kor/extraction/api.py index 912e83c..1a3c72f 100644 --- a/kor/extraction/api.py +++ b/kor/extraction/api.py @@ -49,6 +49,7 @@ def create_extraction_chain( validator: Optional[Validator] = None, input_formatter: InputFormatter = None, instruction_template: Optional[PromptTemplate] = None, + verbose: Optional[bool] = None, **encoder_kwargs: Any, ) -> LLMChain: """Create an extraction chain. @@ -73,12 +74,13 @@ def create_extraction_chain( * "type_description": type description of the node (from TypeDescriptor) * "format_instructions": information on how to format the output (from Encoder) + verbose: if provided, sets the verbosity on the chain, otherwise default + verbosity of the chain will be used encoder_kwargs: Keyword arguments to pass to the encoder class Returns: A langchain chain - Examples: .. code-block:: python @@ -94,6 +96,11 @@ def create_extraction_chain( raise ValueError(f"node must be an Object got {type(node)}") encoder = initialize_encoder(encoder_or_encoder_class, node, **encoder_kwargs) type_descriptor_to_use = initialize_type_descriptors(type_descriptor) + + chain_kwargs = {} + if verbose is not None: + chain_kwargs["verbose"] = verbose + return LLMChain( llm=llm, prompt=create_langchain_prompt( @@ -104,6 +111,7 @@ def create_extraction_chain( instruction_template=instruction_template, input_formatter=input_formatter, ), + **chain_kwargs, ) diff --git a/tests/extraction/test_extraction_with_chain.py b/tests/extraction/test_extraction_with_chain.py index 46c86eb..d8bcad3 100644 --- a/tests/extraction/test_extraction_with_chain.py +++ b/tests/extraction/test_extraction_with_chain.py @@ -1,6 +1,7 @@ """Test that the extraction chain works as expected.""" -from typing import Any, Mapping +from typing import Any, Mapping, Optional +import langchain import pytest from langchain import PromptTemplate from langchain.chains import LLMChain @@ -102,6 +103,24 @@ def test_not_implemented_assertion_raised_for_csv(options: Mapping[str, Any]) -> create_extraction_chain(chat_model, **options) +@pytest.mark.parametrize("verbose", [True, False, None]) +def test_instantiation_with_verbose_flag(verbose: Optional[bool]) -> None: + """Create an extraction chain.""" + chat_model = ToyChatModel(response="hello") + chain = create_extraction_chain( + chat_model, + SIMPLE_OBJECT_SCHEMA, + encoder_or_encoder_class="json", + verbose=verbose, + ) + assert isinstance(chain, LLMChain) + if verbose is None: + expected_verbose = langchain.verbose + else: + expected_verbose = verbose + assert chain.verbose == expected_verbose + + def test_using_custom_template() -> None: """Create an extraction chain with a custom template.""" template = PromptTemplate(