From 34fb56b6334bb9f91f8bf65de3d82cd5955e2917 Mon Sep 17 00:00:00 2001 From: Gabriel Altay Date: Thu, 20 Apr 2023 10:15:41 -0400 Subject: [PATCH 01/16] fix copy/pasta typos wikipedia->arxiv (#3222) just updates a few module level docstrings from Wikipedia -> Arxiv --- langchain/tools/arxiv/tool.py | 2 +- tests/integration_tests/test_arxiv.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/langchain/tools/arxiv/tool.py b/langchain/tools/arxiv/tool.py index 3ebbc1328f439..2c117c8a87641 100644 --- a/langchain/tools/arxiv/tool.py +++ b/langchain/tools/arxiv/tool.py @@ -1,4 +1,4 @@ -"""Tool for the Wikipedia API.""" +"""Tool for the Arxiv API.""" from langchain.tools.base import BaseTool from langchain.utilities.arxiv import ArxivAPIWrapper diff --git a/tests/integration_tests/test_arxiv.py b/tests/integration_tests/test_arxiv.py index bff1d1ee8d155..d7f2fdd1be9d5 100644 --- a/tests/integration_tests/test_arxiv.py +++ b/tests/integration_tests/test_arxiv.py @@ -1,4 +1,4 @@ -"""Integration test for Wikipedia API Wrapper.""" +"""Integration test for Arxiv API Wrapper.""" import pytest from langchain.utilities import ArxivAPIWrapper From b7f2061736ce6009b9a1b50f92e10d2d4f49c5b3 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 20 Apr 2023 07:57:07 -0700 Subject: [PATCH 02/16] Harrison/google places (#3207) Co-authored-by: Cao Hoang <65607230+cnhhoang850@users.noreply.github.com> Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com> --- .../agents/tools/examples/google_places.ipynb | 105 ++++++++++++++++ langchain/tools/__init__.py | 2 + langchain/tools/google_places/__init__.py | 1 + langchain/tools/google_places/tool.py | 27 +++++ langchain/utilities/__init__.py | 2 + langchain/utilities/google_places_api.py | 112 ++++++++++++++++++ 6 files changed, 249 insertions(+) create mode 100644 docs/modules/agents/tools/examples/google_places.ipynb create mode 100644 langchain/tools/google_places/__init__.py create mode 100644 langchain/tools/google_places/tool.py create mode 100644 langchain/utilities/google_places_api.py diff --git a/docs/modules/agents/tools/examples/google_places.ipynb b/docs/modules/agents/tools/examples/google_places.ipynb new file mode 100644 index 0000000000000..68a398ff9affe --- /dev/null +++ b/docs/modules/agents/tools/examples/google_places.ipynb @@ -0,0 +1,105 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "487607cd", + "metadata": {}, + "source": [ + "# Google Places\n", + "\n", + "This notebook goes through how to use Google Places API" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8690845f", + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install googlemaps" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "fae31ef4", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"GPLACES_API_KEY\"] = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "abb502b3", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.tools import GooglePlacesTool" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "a83a02ac", + "metadata": {}, + "outputs": [], + "source": [ + "places = GooglePlacesTool()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "2b65a285", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"1. Delfina Restaurant\\nAddress: 3621 18th St, San Francisco, CA 94110, USA\\nPhone: (415) 552-4055\\nWebsite: https://www.delfinasf.com/\\n\\n\\n2. 
Piccolo Forno\\nAddress: 725 Columbus Ave, San Francisco, CA 94133, USA\\nPhone: (415) 757-0087\\nWebsite: https://piccolo-forno-sf.com/\\n\\n\\n3. L'Osteria del Forno\\nAddress: 519 Columbus Ave, San Francisco, CA 94133, USA\\nPhone: (415) 982-1124\\nWebsite: Unknown\\n\\n\\n4. Il Fornaio\\nAddress: 1265 Battery St, San Francisco, CA 94111, USA\\nPhone: (415) 986-0100\\nWebsite: https://www.ilfornaio.com/\\n\\n\""
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "places.run(\"al fornos\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66d3da8a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
diff --git a/langchain/tools/__init__.py b/langchain/tools/__init__.py
index 44f225e1f28dd..3c034f83b848b 100644
--- a/langchain/tools/__init__.py
+++ b/langchain/tools/__init__.py
@@ -2,6 +2,7 @@
 from langchain.tools.base import BaseTool
 from langchain.tools.ddg_search.tool import DuckDuckGoSearchTool
+from langchain.tools.google_places.tool import GooglePlacesTool
 from langchain.tools.ifttt import IFTTTWebhook
 from langchain.tools.openapi.utils.api_models import APIOperation
 from langchain.tools.openapi.utils.openapi_utils import OpenAPISpec
@@ -13,5 +14,6 @@
     "AIPluginTool",
     "OpenAPISpec",
     "APIOperation",
+    "GooglePlacesTool",
     "DuckDuckGoSearchTool",
 ]
diff --git a/langchain/tools/google_places/__init__.py b/langchain/tools/google_places/__init__.py
new file mode 100644
index 0000000000000..e5b5d5046d86a
--- /dev/null
+++ b/langchain/tools/google_places/__init__.py
@@ -0,0 +1 @@
"""Google Places API Toolkit."""
diff --git a/langchain/tools/google_places/tool.py b/langchain/tools/google_places/tool.py
new file mode 100644
index 0000000000000..31ae39dae80b5
--- /dev/null
+++ b/langchain/tools/google_places/tool.py
@@ -0,0 +1,27 @@
"""Tool for the Google Places API."""

from pydantic import Field

from langchain.tools.base import BaseTool
from langchain.utilities.google_places_api import GooglePlacesAPIWrapper


class GooglePlacesTool(BaseTool):
    """Tool that adds the capability to query the Google Places API."""

    name = "Google Places"
    description = (
        "A wrapper around Google Places. "
        "Useful for when you need to validate or "
        "discover addresses from ambiguous text. "
        "Input should be a search query."
    )
    api_wrapper: GooglePlacesAPIWrapper = Field(default_factory=GooglePlacesAPIWrapper)

    def _run(self, query: str) -> str:
        """Use the tool."""
        return self.api_wrapper.run(query)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("GooglePlacesTool does not support async")
diff --git a/langchain/utilities/__init__.py b/langchain/utilities/__init__.py
index a41aa842ed99b..f834601d77add 100644
--- a/langchain/utilities/__init__.py
+++ b/langchain/utilities/__init__.py
@@ -4,6 +4,7 @@
 from langchain.utilities.arxiv import ArxivAPIWrapper
 from langchain.utilities.bash import BashProcess
 from langchain.utilities.bing_search import BingSearchAPIWrapper
+from langchain.utilities.google_places_api import GooglePlacesAPIWrapper
 from langchain.utilities.google_search import GoogleSearchAPIWrapper
 from langchain.utilities.google_serper import GoogleSerperAPIWrapper
 from langchain.utilities.openweathermap import OpenWeatherMapAPIWrapper
@@ -20,6 +21,7 @@
     "TextRequestsWrapper",
     "GoogleSearchAPIWrapper",
     "GoogleSerperAPIWrapper",
+    "GooglePlacesAPIWrapper",
     "WolframAlphaAPIWrapper",
     "SerpAPIWrapper",
     "SearxSearchWrapper",
diff --git a/langchain/utilities/google_places_api.py b/langchain/utilities/google_places_api.py
new file mode 100644
index 0000000000000..585a52424a33c
--- /dev/null
+++ b/langchain/utilities/google_places_api.py
@@ -0,0 +1,112 @@
"""Utility that calls the Google Places API."""

import logging
from typing import Any, Dict, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.utils import get_from_dict_or_env


class GooglePlacesAPIWrapper(BaseModel):
    """Wrapper around the Google Places API.

    To use, you should have the ``googlemaps`` python package installed,
    **an API key for the Google Maps platform**,
    and the environment variable ``GPLACES_API_KEY``
    set with your API key, or pass ``gplaces_api_key``
    as a named parameter to the constructor.

    By default, this will return all the results on the input query.
    You can use the top_k_results argument to limit the number of results.

    Example:
        .. code-block:: python

            from langchain import GooglePlacesAPIWrapper
            gplaceapi = GooglePlacesAPIWrapper()
    """

    gplaces_api_key: Optional[str] = None
    google_map_client: Any  #: :meta private:
    top_k_results: Optional[int] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key is in your environment variable."""
        gplaces_api_key = get_from_dict_or_env(
            values, "gplaces_api_key", "GPLACES_API_KEY"
        )
        values["gplaces_api_key"] = gplaces_api_key
        try:
            import googlemaps

            values["google_map_client"] = googlemaps.Client(gplaces_api_key)
        except ImportError:
            raise ValueError(
                "Could not import googlemaps python package. "
                "Please install it with `pip install googlemaps`."
            )
        return values

    def run(self, query: str) -> str:
        """Run a Places search and return up to k places that match the query."""
        search_results = self.google_map_client.places(query)["results"]
        num_to_return = len(search_results)

        places = []

        if num_to_return == 0:
            return "Google Places did not find any places that match the description"

        num_to_return = (
            num_to_return
            if self.top_k_results is None
            else min(num_to_return, self.top_k_results)
        )

        for i in range(num_to_return):
            result = search_results[i]
            details = self.fetch_place_details(result["place_id"])

            if details is not None:
                places.append(details)

        return "\n".join([f"{i+1}. {item}" for i, item in enumerate(places)])

    def fetch_place_details(self, place_id: str) -> Optional[str]:
        try:
            place_details = self.google_map_client.place(place_id)
            formatted_details = self.format_place_details(place_details)
            return formatted_details
        except Exception as e:
            logging.error(f"An error occurred while fetching place details: {e}")
            return None

    def format_place_details(self, place_details: Dict[str, Any]) -> Optional[str]:
        try:
            name = place_details.get("result", {}).get("name", "Unknown")
            address = place_details.get("result", {}).get(
                "formatted_address", "Unknown"
            )
            phone_number = place_details.get("result", {}).get(
                "formatted_phone_number", "Unknown"
            )
            website = place_details.get("result", {}).get("website", "Unknown")

            formatted_details = (
                f"{name}\nAddress: {address}\n"
                f"Phone: {phone_number}\nWebsite: {website}\n\n"
            )
            return formatted_details
        except Exception as e:
            logging.error(f"An error occurred while formatting place details: {e}")
            return None

From b7dea80cbadb853915007be79dc6a35e5d89d40e Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Thu, 20 Apr 2023 08:30:38 -0700
Subject: [PATCH 03/16] bump version to 145 (#3229)

---
 pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 34252b19ed773..51878414f26cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain"
-version = "0.0.144"
+version = "0.0.145"
 description = "Building applications with LLMs through composability"
 authors = []
 license = "MIT"
@@ -140,7 +140,7 @@ llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifes
 qdrant = ["qdrant-client"]
 openai = ["openai"]
 cohere = ["cohere"]
-all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search"]
+all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence_transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv"]

 [tool.ruff]
 select = [

From d54b977d4e31f31765b8113a2c520058503d0bd9 Mon Sep 17 00:00:00 2001
From: Peter Stolz <50801264+PeterStolz@users.noreply.github.com>
Date: Thu, 20 Apr 2023 19:46:51 +0200
Subject: [PATCH 04/16] Fix docstring of RetrievalQA (#3231)

Structure changed and RetrievalQA now expects a BaseRetriever, not a VectorStore.
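An equivalent, shorter form uses the vector store's `as_retriever()` helper; a minimal sketch, with `vectordb` standing in for any already-built vector store:

```python
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI

# vectordb is a placeholder for an existing vector store,
# e.g. FAISS.from_documents(docs, OpenAIEmbeddings()).
retriever = vectordb.as_retriever()  # wraps the store in a VectorStoreRetriever
retrieval_qa = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)
```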
"aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "boto3", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv"] [tool.ruff] select = [ From d54b977d4e31f31765b8113a2c520058503d0bd9 Mon Sep 17 00:00:00 2001 From: Peter Stolz <50801264+PeterStolz@users.noreply.github.com> Date: Thu, 20 Apr 2023 19:46:51 +0200 Subject: [PATCH 04/16] Fix docstring of RetrievalQA (#3231) Structure changed an RetrievalQA now expects BaseRetriever not VectorStore --- langchain/chains/retrieval_qa/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/langchain/chains/retrieval_qa/base.py b/langchain/chains/retrieval_qa/base.py index bcc26b16f3c08..dc1d68bfdf07d 100644 --- a/langchain/chains/retrieval_qa/base.py +++ b/langchain/chains/retrieval_qa/base.py @@ -154,8 +154,9 @@ class RetrievalQA(BaseRetrievalQA): from langchain.llms import OpenAI from langchain.chains import RetrievalQA from langchain.faiss import FAISS - vectordb = FAISS(...) - retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=vectordb) + from langchain.vectorstores.base import VectorStoreRetriever + retriever = VectorStoreRetriever(vectorstore=FAISS(...)) + retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever) """ From 130e4b9fcb0a0fbea5096e45ffd933921e56621d Mon Sep 17 00:00:00 2001 From: leo-gan Date: Thu, 20 Apr 2023 10:47:16 -0700 Subject: [PATCH 05/16] fixed a link to the youtube page (#3232) A link to the `YouTube` page was missing on the `index` page. --- docs/index.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/index.rst b/docs/index.rst index 32dc3e870a92a..04e3abcb9df2f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -177,4 +177,5 @@ Additional collection of resources we think may be useful as you develop your ap ./tracing.md ./use_cases/model_laboratory.ipynb Discord + ./youtube.md Production Support From 8f22949dc486e49625f4913c60093b4f65c7bac6 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 20 Apr 2023 11:53:23 -0700 Subject: [PATCH 06/16] update nnotebook title --- .../examples/discord_loader.ipynb | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/docs/modules/indexes/document_loaders/examples/discord_loader.ipynb b/docs/modules/indexes/document_loaders/examples/discord_loader.ipynb index d8f0e8deaf6ba..cd24804d2a70f 100644 --- a/docs/modules/indexes/document_loaders/examples/discord_loader.ipynb +++ b/docs/modules/indexes/document_loaders/examples/discord_loader.ipynb @@ -1,11 +1,10 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ - "# How to download your Discord data\n", + "# Discord\n", "\n", "You can follow the below steps to download your Discord data:\n", "\n", @@ -65,10 +64,23 @@ } ], "metadata": { - "language_info": { - "name": "python" + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" }, - "orig_nbformat": 4 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } }, "nbformat": 4, "nbformat_minor": 2 From daee0b2b9754b13d87b25c7dfbf5298ec22f3eb6 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Thu, 20 Apr 2023 13:31:30 -0700 Subject: [PATCH 07/16] Patch Chat History Formatting (#3236) While we work on solidifying 
---
 .../chains/conversational_retrieval/base.py | 30 +++++++++++++++++++-------
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py
index 97424ecbfe066..b7fb299e869d9 100644
--- a/langchain/chains/conversational_retrieval/base.py
+++ b/langchain/chains/conversational_retrieval/base.py
@@ -15,16 +15,32 @@
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseLanguageModel, BaseRetriever, Document
+from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document
 from langchain.vectorstores.base import VectorStore

+# Depending on the memory type and configuration, the chat history format may differ.
+# This needs to be consolidated.
+CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]

-def _get_chat_history(chat_history: List[Tuple[str, str]]) -> str:
+
+_ROLE_MAP = {"human": "Human: ", "ai": "Assistant: "}
+
+
+def _get_chat_history(chat_history: List[CHAT_TURN_TYPE]) -> str:
     buffer = ""
-    for human_s, ai_s in chat_history:
-        human = "Human: " + human_s
-        ai = "Assistant: " + ai_s
-        buffer += "\n" + "\n".join([human, ai])
+    for dialogue_turn in chat_history:
+        if isinstance(dialogue_turn, BaseMessage):
+            role_prefix = _ROLE_MAP.get(dialogue_turn.type, f"{dialogue_turn.type}: ")
+            buffer += f"\n{role_prefix}{dialogue_turn.content}"
+        elif isinstance(dialogue_turn, tuple):
+            human = "Human: " + dialogue_turn[0]
+            ai = "Assistant: " + dialogue_turn[1]
+            buffer += "\n" + "\n".join([human, ai])
+        else:
+            raise ValueError(
+                f"Unsupported chat history format: {type(dialogue_turn)}."
+                f" Full chat history: {chat_history} "
+            )
     return buffer

@@ -35,7 +51,7 @@ class BaseConversationalRetrievalChain(Chain):
     question_generator: LLMChain
     output_key: str = "answer"
     return_source_documents: bool = False
-    get_chat_history: Optional[Callable[[Tuple[str, str]], str]] = None
+    get_chat_history: Optional[Callable[[CHAT_TURN_TYPE], str]] = None
     """Return the source documents."""

     class Config:

From 7d3e6389f26da27ef4fcb9466795fe4622ab1e5a Mon Sep 17 00:00:00 2001
From: Tom Dyson
Date: Thu, 20 Apr 2023 22:02:20 +0100
Subject: [PATCH 08/16] Add DuckDB prompt (#3233)

Adds a prompt template for the DuckDB SQL dialect.
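Since `SQL_PROMPTS` is keyed by dialect (see the last hunk below), a chain built over a DuckDB connection should pick this prompt up automatically; a sketch, assuming the `duckdb-engine` SQLAlchemy driver and an illustrative `example.db` file:

```python
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain

db = SQLDatabase.from_uri("duckdb:///example.db")  # dialect resolves to "duckdb"
chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db)
chain.run("How many employees are there?")
```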
---
 langchain/chains/sql_database/prompt.py | 22 ++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/langchain/chains/sql_database/prompt.py b/langchain/chains/sql_database/prompt.py
index 0ced8a22b0e01..8fd3b46aeeeed 100644
--- a/langchain/chains/sql_database/prompt.py
+++ b/langchain/chains/sql_database/prompt.py
@@ -40,6 +40,27 @@
     output_parser=CommaSeparatedListOutputParser(),
 )

+_duckdb_prompt = """You are a DuckDB expert. Given an input question, first create a syntactically correct DuckDB query to run, then look at the results of the query and return the answer to the input question.
+Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per DuckDB. You can order the results to return the most informative data in the database.
+Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
+Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
+
+Use the following format:
+
+Question: "Question here"
+SQLQuery: "SQL Query to run"
+SQLResult: "Result of the SQLQuery"
+Answer: "Final answer here"
+
+Only use the following tables:
+{table_info}
+
+Question: {input}"""
+
+DUCKDB_PROMPT = PromptTemplate(
+    input_variables=["input", "table_info", "top_k"],
+    template=_duckdb_prompt,
+)

 _googlesql_prompt = """You are a GoogleSQL expert. Given an input question, first create a syntactically correct GoogleSQL query to run, then look at the results of the query and return the answer to the input question.
 Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per GoogleSQL. You can order the results to return the most informative data in the database.
@@ -201,6 +222,7 @@

 SQL_PROMPTS = {
+    "duckdb": DUCKDB_PROMPT,
     "googlesql": GOOGLESQL_PROMPT,
     "mssql": MSSQL_PROMPT,
     "mysql": MYSQL_PROMPT,

From ae528fd06e005992e6b0d6a623acb83de27c578d Mon Sep 17 00:00:00 2001
From: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
Date: Thu, 20 Apr 2023 15:03:32 -0600
Subject: [PATCH 09/16] fix error msg ref to beautifulsoup4 (#3242)

Co-authored-by: Daniel Chalef

---
 langchain/document_loaders/html_bs.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/langchain/document_loaders/html_bs.py b/langchain/document_loaders/html_bs.py
index c90e32d709dc7..fc636367a31ca 100644
--- a/langchain/document_loaders/html_bs.py
+++ b/langchain/document_loaders/html_bs.py
@@ -24,7 +24,8 @@ def __init__(
             import bs4  # noqa:F401
         except ImportError:
             raise ValueError(
-                "bs4 package not found, please install it with " "`pip install bs4`"
+                "beautifulsoup4 package not found, please install it with "
+                "`pip install beautifulsoup4`"
             )

         self.file_path = file_path

From 0e797a3ff993070440927aa7549493d3bf885eb4 Mon Sep 17 00:00:00 2001
From: Boris Feld
Date: Thu, 20 Apr 2023 23:57:41 +0200
Subject: [PATCH 10/16] Fixing issue link for Comet callback (#3212)

Sorry, I fixed that link once but there was still a typo inside; this time it should be good.

---
 langchain/callbacks/comet_ml_callback.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/langchain/callbacks/comet_ml_callback.py b/langchain/callbacks/comet_ml_callback.py
index 7917d26a50537..877a8f903738d 100644
--- a/langchain/callbacks/comet_ml_callback.py
+++ b/langchain/callbacks/comet_ml_callback.py
@@ -130,7 +130,7 @@ def __init__(
         warning = (
             "The comet_ml callback is currently in beta and is subject to change "
             "based on updates to `langchain`. Please report any issues to "
-            "https://github.com/comet-ml/issue_tracking/issues with the tag "
+            "https://github.com/comet-ml/issue-tracking/issues with the tag "
             "`langchain`."
) self.comet_ml.LOGGER.warning(warning) From 0684aa081a26b0773471e02f842c6af8618c4ebc Mon Sep 17 00:00:00 2001 From: Albert Castellana Date: Fri, 21 Apr 2023 00:20:21 +0200 Subject: [PATCH 11/16] Ecosystem/Yeager.ai (#3239) Added yeagerai.md to ecosystem --- docs/ecosystem/yeagerai.md | 43 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 docs/ecosystem/yeagerai.md diff --git a/docs/ecosystem/yeagerai.md b/docs/ecosystem/yeagerai.md new file mode 100644 index 0000000000000..6483cce900151 --- /dev/null +++ b/docs/ecosystem/yeagerai.md @@ -0,0 +1,43 @@ +# Yeager.ai + +This page covers how to use [Yeager.ai](https://yeager.ai) to generate LangChain tools and agents. + +## What is Yeager.ai? +Yeager.ai is an ecosystem designed to simplify the process of creating AI agents and tools. + +It features yAgents, a No-code LangChain Agent Builder, which enables users to build, test, and deploy AI solutions with ease. Leveraging the LangChain framework, yAgents allows seamless integration with various language models and resources, making it suitable for developers, researchers, and AI enthusiasts across diverse applications. + +## yAgents +Low code generative agent designed to help you build, prototype, and deploy Langchain tools with ease. + +### How to use? +``` +pip install yeagerai-agent +yeagerai-agent +``` +Go to http://127.0.0.1:7860 + +This will install the necessary dependencies and set up yAgents on your system. After the first run, yAgents will create a .env file where you can input your OpenAI API key. You can do the same directly from the Gradio interface under the tab "Settings". + +`OPENAI_API_KEY=` + +We recommend using GPT-4,. However, the tool can also work with GPT-3 if the problem is broken down sufficiently. + +### Creating and Executing Tools with yAgents +yAgents makes it easy to create and execute AI-powered tools. Here's a brief overview of the process: +1. Create a tool: To create a tool, provide a natural language prompt to yAgents. The prompt should clearly describe the tool's purpose and functionality. For example: +`create a tool that returns the n-th prime number` + +2. Load the tool into the toolkit: To load a tool into yAgents, simply provide a command to yAgents that says so. For example: +`load the tool that you just created it into your toolkit` + +3. Execute the tool: To run a tool or agent, simply provide a command to yAgents that includes the name of the tool and any required parameters. For example: +`generate the 50th prime number` + +You can see a video of how it works [here](https://www.youtube.com/watch?v=KA5hCM3RaWE). + +As you become more familiar with yAgents, you can create more advanced tools and agents to automate your work and enhance your productivity. 
+ +For more information, see [yAgents' Github](https://github.com/yeagerai/yeagerai-agent) or our [docs](https://yeagerai.gitbook.io/docs/general/welcome-to-yeager.ai) + + From 2dbb5261b5ba047c4ce936b5c7c50ae4e8c95a88 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 20 Apr 2023 15:37:56 -0700 Subject: [PATCH 12/16] wikibase agent --- docs/{modules/agents => use_cases}/agents/wikibase_agent.ipynb | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/{modules/agents => use_cases}/agents/wikibase_agent.ipynb (100%) diff --git a/docs/modules/agents/agents/wikibase_agent.ipynb b/docs/use_cases/agents/wikibase_agent.ipynb similarity index 100% rename from docs/modules/agents/agents/wikibase_agent.ipynb rename to docs/use_cases/agents/wikibase_agent.ipynb From 5ef2d1e2a196e41733b94e1e445a09ab49bc582d Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 20 Apr 2023 15:43:57 -0700 Subject: [PATCH 13/16] add to docs --- docs/use_cases/personal_assistants.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/use_cases/personal_assistants.md b/docs/use_cases/personal_assistants.md index ebf17072e23fe..615a6ce4419d4 100644 --- a/docs/use_cases/personal_assistants.md +++ b/docs/use_cases/personal_assistants.md @@ -20,3 +20,4 @@ Highlighting specific parts: Specific examples of this include: - [AI Plugins](agents/custom_agent_with_plugin_retrieval.ipynb): an implementation of an agent that is designed to be able to use all AI Plugins. +- [Wikibase Agent](agents/wikibase_agent.ipynb): an implementation of an agent that is designed to interact with Wikibase. From 3943759a90de0c6678b3fe64311f69029daf0dab Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Thu, 20 Apr 2023 18:51:49 -0400 Subject: [PATCH 14/16] feat: add loader for rich text files (#3227) ### Summary Adds a loader for rich text files. Requires `unstructured>=0.5.12`. ### Testing The following test uses the example RTF file from the [`unstructured` repo](https://github.com/Unstructured-IO/unstructured/tree/main/example-docs). 
```python from langchain.document_loaders import UnstructuredRTFLoader loader = UnstructuredRTFLoader("fake-doc.rtf", mode="elements") docs = loader.load() docs[0].page_content ``` --- langchain/document_loaders/__init__.py | 2 ++ langchain/document_loaders/rtf.py | 28 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 langchain/document_loaders/rtf.py diff --git a/langchain/document_loaders/__init__.py b/langchain/document_loaders/__init__.py index ab5fc4b0de384..c4cc744838f54 100644 --- a/langchain/document_loaders/__init__.py +++ b/langchain/document_loaders/__init__.py @@ -57,6 +57,7 @@ from langchain.document_loaders.powerpoint import UnstructuredPowerPointLoader from langchain.document_loaders.readthedocs import ReadTheDocsLoader from langchain.document_loaders.roam import RoamLoader +from langchain.document_loaders.rtf import UnstructuredRTFLoader from langchain.document_loaders.s3_directory import S3DirectoryLoader from langchain.document_loaders.s3_file import S3FileLoader from langchain.document_loaders.sitemap import SitemapLoader @@ -106,6 +107,7 @@ "OutlookMessageLoader", "UnstructuredEPubLoader", "UnstructuredMarkdownLoader", + "UnstructuredRTFLoader", "RoamLoader", "YoutubeLoader", "S3FileLoader", diff --git a/langchain/document_loaders/rtf.py b/langchain/document_loaders/rtf.py new file mode 100644 index 0000000000000..c4113be206294 --- /dev/null +++ b/langchain/document_loaders/rtf.py @@ -0,0 +1,28 @@ +"""Loader that loads rich text files.""" +from typing import Any, List + +from langchain.document_loaders.unstructured import ( + UnstructuredFileLoader, + satisfies_min_unstructured_version, +) + + +class UnstructuredRTFLoader(UnstructuredFileLoader): + """Loader that uses unstructured to load rtf files.""" + + def __init__( + self, file_path: str, mode: str = "single", **unstructured_kwargs: Any + ): + min_unstructured_version = "0.5.12" + if not satisfies_min_unstructured_version(min_unstructured_version): + raise ValueError( + "Partitioning rtf files is only supported in " + f"unstructured>={min_unstructured_version}." 
+ ) + + super().__init__(file_path=file_path, mode=mode, **unstructured_kwargs) + + def _get_elements(self) -> List: + from unstructured.partition.rtf import partition_rtf + + return partition_rtf(filename=self.file_path, **self.unstructured_kwargs) From 46542dc7745cc837dc549b57f4ed4928e0999833 Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Thu, 20 Apr 2023 17:01:14 -0700 Subject: [PATCH 15/16] Contextual compression retriever (#2915) Co-authored-by: Harrison Chase --- .../examples/contextual-compression.ipynb | 371 ++++++++++++++++++ langchain/document_transformers.py | 100 +++++ langchain/math_utils.py | 22 ++ langchain/output_parsers/boolean.py | 29 ++ langchain/retrievers/__init__.py | 2 + .../retrievers/contextual_compression.py | 51 +++ .../document_compressors/__init__.py | 17 + .../retrievers/document_compressors/base.py | 61 +++ .../document_compressors/chain_extract.py | 77 ++++ .../chain_extract_prompt.py | 11 + .../document_compressors/chain_filter.py | 65 +++ .../chain_filter_prompt.py | 9 + .../document_compressors/embeddings_filter.py | 70 ++++ langchain/schema.py | 25 +- langchain/text_splitter.py | 19 +- langchain/vectorstores/utils.py | 29 +- .../integration_tests/retrievers/__init__.py | 0 .../document_compressors/__init__.py | 0 .../document_compressors/test_base.py | 28 ++ .../test_chain_extract.py | 36 ++ .../document_compressors/test_chain_filter.py | 17 + .../test_embeddings_filter.py | 39 ++ .../retrievers/test_contextual_compression.py | 25 ++ .../test_document_transformers.py | 31 ++ .../unit_tests/test_document_transformers.py | 15 + tests/unit_tests/test_math_utils.py | 39 ++ 26 files changed, 1158 insertions(+), 30 deletions(-) create mode 100644 docs/modules/indexes/retrievers/examples/contextual-compression.ipynb create mode 100644 langchain/document_transformers.py create mode 100644 langchain/math_utils.py create mode 100644 langchain/output_parsers/boolean.py create mode 100644 langchain/retrievers/contextual_compression.py create mode 100644 langchain/retrievers/document_compressors/__init__.py create mode 100644 langchain/retrievers/document_compressors/base.py create mode 100644 langchain/retrievers/document_compressors/chain_extract.py create mode 100644 langchain/retrievers/document_compressors/chain_extract_prompt.py create mode 100644 langchain/retrievers/document_compressors/chain_filter.py create mode 100644 langchain/retrievers/document_compressors/chain_filter_prompt.py create mode 100644 langchain/retrievers/document_compressors/embeddings_filter.py create mode 100644 tests/integration_tests/retrievers/__init__.py create mode 100644 tests/integration_tests/retrievers/document_compressors/__init__.py create mode 100644 tests/integration_tests/retrievers/document_compressors/test_base.py create mode 100644 tests/integration_tests/retrievers/document_compressors/test_chain_extract.py create mode 100644 tests/integration_tests/retrievers/document_compressors/test_chain_filter.py create mode 100644 tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py create mode 100644 tests/integration_tests/retrievers/test_contextual_compression.py create mode 100644 tests/integration_tests/test_document_transformers.py create mode 100644 tests/unit_tests/test_document_transformers.py create mode 100644 tests/unit_tests/test_math_utils.py diff --git a/docs/modules/indexes/retrievers/examples/contextual-compression.ipynb 
b/docs/modules/indexes/retrievers/examples/contextual-compression.ipynb
new file mode 100644
index 0000000000000..9f299c6b0ac35
--- /dev/null
+++ b/docs/modules/indexes/retrievers/examples/contextual-compression.ipynb
@@ -0,0 +1,371 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fc0db1bc",
   "metadata": {},
   "source": [
    "# Contextual Compression Retriever\n",
    "\n",
    "This notebook introduces the concept of DocumentCompressors and the ContextualCompressionRetriever. The core idea is simple: given a specific query, we should be able to return only the documents relevant to that query, and only the parts of those documents that are relevant. The ContextualCompressionRetriever is a wrapper for another retriever that iterates over the initial output of the base retriever and filters and compresses those initial documents, so that only the most relevant information is returned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "28e8dc12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper function for printing docs\n",
    "\n",
    "def pretty_print_docs(docs):\n",
    "    print(f\"\\n{'-' * 100}\\n\".join([f\"Document {i+1}:\\n\\n\" + d.page_content for i, d in enumerate(docs)]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fa3d916",
   "metadata": {},
   "source": [
    "## Using a vanilla vector store retriever\n",
    "Let's start by initializing a simple vector store retriever and storing the 2023 State of the Union speech (in chunks). We can see that given an example question our retriever returns one or two relevant docs and a few irrelevant docs. And even the relevant docs have a lot of irrelevant information in them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9fbcc58f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Document 1:\n",
      "\n",
      "Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
      "\n",
      "Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
      "\n",
      "One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
      "\n",
      "And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Document 2:\n",
      "\n",
      "A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. \n",
      "\n",
      "And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system. \n",
      "\n",
      "We can do both. At our border, we’ve installed new technology like cutting-edge scanners to better detect drug smuggling. 
\n", + "\n", + "We’ve set up joint patrols with Mexico and Guatemala to catch more human traffickers. \n", + "\n", + "We’re putting in place dedicated immigration judges so families fleeing persecution and violence can have their cases heard faster. \n", + "\n", + "We’re securing commitments and supporting partners in South and Central America to host more refugees and secure their own borders.\n", + "----------------------------------------------------------------------------------------------------\n", + "Document 3:\n", + "\n", + "And for our LGBTQ+ Americans, let’s finally get the bipartisan Equality Act to my desk. The onslaught of state laws targeting transgender Americans and their families is wrong. \n", + "\n", + "As I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential. \n", + "\n", + "While it often appears that we never agree, that isn’t true. I signed 80 bipartisan bills into law last year. From preventing government shutdowns to protecting Asian-Americans from still-too-common hate crimes to reforming military justice. \n", + "\n", + "And soon, we’ll strengthen the Violence Against Women Act that I first wrote three decades ago. It is important for us to show the nation that we can come together and do big things. \n", + "\n", + "So tonight I’m offering a Unity Agenda for the Nation. Four big things we can do together. \n", + "\n", + "First, beat the opioid epidemic.\n", + "----------------------------------------------------------------------------------------------------\n", + "Document 4:\n", + "\n", + "Tonight, I’m announcing a crackdown on these companies overcharging American businesses and consumers. \n", + "\n", + "And as Wall Street firms take over more nursing homes, quality in those homes has gone down and costs have gone up. \n", + "\n", + "That ends on my watch. \n", + "\n", + "Medicare is going to set higher standards for nursing homes and make sure your loved ones get the care they deserve and expect. \n", + "\n", + "We’ll also cut costs and keep the economy going strong by giving workers a fair shot, provide more training and apprenticeships, hire them based on their skills not degrees. \n", + "\n", + "Let’s pass the Paycheck Fairness Act and paid leave. \n", + "\n", + "Raise the minimum wage to $15 an hour and extend the Child Tax Credit, so no one has to raise a family in poverty. 
\n", + "\n", + "Let’s increase Pell Grants and increase our historic support of HBCUs, and invest in what Jill—our First Lady who teaches full-time—calls America’s best-kept secret: community colleges.\n" + ] + } + ], + "source": [ + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain.embeddings import OpenAIEmbeddings\n", + "from langchain.document_loaders import TextLoader\n", + "from langchain.vectorstores import FAISS\n", + "\n", + "documents = TextLoader('../../../state_of_the_union.txt').load()\n", + "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "texts = text_splitter.split_documents(documents)\n", + "retriever = FAISS.from_documents(texts, OpenAIEmbeddings()).as_retriever()\n", + "\n", + "docs = retriever.get_relevant_documents(\"What did the president say about Ketanji Brown Jackson\")\n", + "pretty_print_docs(docs)" + ] + }, + { + "cell_type": "markdown", + "id": "b7648612", + "metadata": {}, + "source": [ + "## Adding contextual compression with an `LLMChainExtractor`\n", + "Now let's wrap our base retriever with a `ContextualCompressionRetriever`. We'll add an `LLMChainExtractor`, which will iterate over the initially returned documents and extract from each only the content that is relevant to the query." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9a658023", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document 1:\n", + "\n", + "\"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n", + "\n", + "And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\"\n", + "----------------------------------------------------------------------------------------------------\n", + "Document 2:\n", + "\n", + "\"A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\"\n" + ] + } + ], + "source": [ + "from langchain.llms import OpenAI\n", + "from langchain.retrievers import ContextualCompressionRetriever\n", + "from langchain.retrievers.document_compressors import LLMChainExtractor\n", + "\n", + "llm = OpenAI(temperature=0)\n", + "compressor = LLMChainExtractor.from_llm(llm)\n", + "compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)\n", + "\n", + "compressed_docs = compression_retriever.get_relevant_documents(\"What did the president say about Ketanji Jackson Brown\")\n", + "pretty_print_docs(compressed_docs)" + ] + }, + { + "cell_type": "markdown", + "id": "2cd38f3a", + "metadata": {}, + "source": [ + "## More built-in compressors: filters\n", + "### `LLMChainFilter`\n", + "The `LLMChainFilter` is slightly simpler but more robust compressor that uses an LLM chain to decide which of the initially retrieved documents to filter out and which ones to return, without manipulating the document contents." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b216a767", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document 1:\n", + "\n", + "Tonight. 
I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n", + "\n", + "Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n", + "\n", + "One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n", + "\n", + "And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n" + ] + } + ], + "source": [ + "from langchain.retrievers.document_compressors import LLMChainFilter\n", + "\n", + "_filter = LLMChainFilter.from_llm(llm)\n", + "compression_retriever = ContextualCompressionRetriever(base_compressor=_filter, base_retriever=retriever)\n", + "\n", + "compressed_docs = compression_retriever.get_relevant_documents(\"What did the president say about Ketanji Jackson Brown\")\n", + "pretty_print_docs(compressed_docs)" + ] + }, + { + "cell_type": "markdown", + "id": "8c709598", + "metadata": {}, + "source": [ + "### `EmbeddingsFilter`\n", + "\n", + "Making an extra LLM call over each retrieved document is expensive and slow. The `EmbeddingsFilter` provides a cheaper and faster option by embedding the documents and query and only returning those documents which have sufficiently similar embeddings to the query." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6fbc801f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document 1:\n", + "\n", + "Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \n", + "\n", + "Tonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n", + "\n", + "One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n", + "\n", + "And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n", + "----------------------------------------------------------------------------------------------------\n", + "Document 2:\n", + "\n", + "A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. \n", + "\n", + "And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system. \n", + "\n", + "We can do both. At our border, we’ve installed new technology like cutting-edge scanners to better detect drug smuggling. \n", + "\n", + "We’ve set up joint patrols with Mexico and Guatemala to catch more human traffickers. 
\n", + "\n", + "We’re putting in place dedicated immigration judges so families fleeing persecution and violence can have their cases heard faster. \n", + "\n", + "We’re securing commitments and supporting partners in South and Central America to host more refugees and secure their own borders.\n", + "----------------------------------------------------------------------------------------------------\n", + "Document 3:\n", + "\n", + "And for our LGBTQ+ Americans, let’s finally get the bipartisan Equality Act to my desk. The onslaught of state laws targeting transgender Americans and their families is wrong. \n", + "\n", + "As I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential. \n", + "\n", + "While it often appears that we never agree, that isn’t true. I signed 80 bipartisan bills into law last year. From preventing government shutdowns to protecting Asian-Americans from still-too-common hate crimes to reforming military justice. \n", + "\n", + "And soon, we’ll strengthen the Violence Against Women Act that I first wrote three decades ago. It is important for us to show the nation that we can come together and do big things. \n", + "\n", + "So tonight I’m offering a Unity Agenda for the Nation. Four big things we can do together. \n", + "\n", + "First, beat the opioid epidemic.\n" + ] + } + ], + "source": [ + "from langchain.embeddings import OpenAIEmbeddings\n", + "from langchain.retrievers.document_compressors import EmbeddingsFilter\n", + "\n", + "embeddings = OpenAIEmbeddings()\n", + "embeddings_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)\n", + "compression_retriever = ContextualCompressionRetriever(base_compressor=embeddings_filter, base_retriever=retriever)\n", + "\n", + "compressed_docs = compression_retriever.get_relevant_documents(\"What did the president say about Ketanji Jackson Brown\")\n", + "pretty_print_docs(compressed_docs)" + ] + }, + { + "cell_type": "markdown", + "id": "07365d36", + "metadata": {}, + "source": [ + "# Stringing compressors and document transformers together\n", + "Using the `DocumentCompressorPipeline` we can also easily combine multiple compressors in sequence. Along with compressors we can add `BaseDocumentTransformer`s to our pipeline, which don't perform any contextual compression but simply perform some transformation on a set of documents. For example `TextSplitter`s can be used as document transformers to split documents into smaller pieces, and the `EmbeddingsRedundantFilter` can be used to filter out redundant documents based on embedding similarity between documents.\n", + "\n", + "Below we create a compressor pipeline by first splitting our docs into smaller chunks, then removing redundant documents, and then filtering based on relevance to the query." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2a150a63", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.document_transformers import EmbeddingsRedundantFilter\n", + "from langchain.retrievers.document_compressors import DocumentCompressorPipeline\n", + "from langchain.text_splitter import CharacterTextSplitter\n", + "\n", + "splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=\". 
\")\n", + "redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)\n", + "relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)\n", + "pipeline_compressor = DocumentCompressorPipeline(\n", + " transformers=[splitter, redundant_filter, relevant_filter]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3ceab64a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document 1:\n", + "\n", + "One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n", + "\n", + "And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson\n", + "----------------------------------------------------------------------------------------------------\n", + "Document 2:\n", + "\n", + "As I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential. \n", + "\n", + "While it often appears that we never agree, that isn’t true. I signed 80 bipartisan bills into law last year\n", + "----------------------------------------------------------------------------------------------------\n", + "Document 3:\n", + "\n", + "A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder\n" + ] + } + ], + "source": [ + "compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=retriever)\n", + "\n", + "compressed_docs = compression_retriever.get_relevant_documents(\"What did the president say about Ketanji Jackson Brown\")\n", + "pretty_print_docs(compressed_docs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8cfd9fc5", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/document_transformers.py b/langchain/document_transformers.py new file mode 100644 index 0000000000000..7f17cb689852e --- /dev/null +++ b/langchain/document_transformers.py @@ -0,0 +1,100 @@ +"""Transform documents""" +from typing import Any, Callable, List, Sequence + +import numpy as np +from pydantic import BaseModel, Field + +from langchain.embeddings.base import Embeddings +from langchain.math_utils import cosine_similarity +from langchain.schema import BaseDocumentTransformer, Document + + +class _DocumentWithState(Document): + """Wrapper for a document that includes arbitrary state.""" + + state: dict = Field(default_factory=dict) + """State associated with the document.""" + + def to_document(self) -> Document: + """Convert the DocumentWithState to a Document.""" + return Document(page_content=self.page_content, metadata=self.metadata) + + @classmethod + def from_document(cls, doc: Document) -> "_DocumentWithState": + """Create a DocumentWithState from a Document.""" + if isinstance(doc, cls): + return doc + return cls(page_content=doc.page_content, metadata=doc.metadata) + + +def 
get_stateful_documents( + documents: Sequence[Document], +) -> Sequence[_DocumentWithState]: + return [_DocumentWithState.from_document(doc) for doc in documents] + + +def _filter_similar_embeddings( + embedded_documents: List[List[float]], similarity_fn: Callable, threshold: float +) -> List[int]: + """Filter redundant documents based on the similarity of their embeddings.""" + similarity = np.tril(similarity_fn(embedded_documents, embedded_documents), k=-1) + redundant = np.where(similarity > threshold) + redundant_stacked = np.column_stack(redundant) + redundant_sorted = np.argsort(similarity[redundant])[::-1] + included_idxs = set(range(len(embedded_documents))) + for first_idx, second_idx in redundant_stacked[redundant_sorted]: + if first_idx in included_idxs and second_idx in included_idxs: + # Default to dropping the second document of any highly similar pair. + included_idxs.remove(second_idx) + return list(sorted(included_idxs)) + + +def _get_embeddings_from_stateful_docs( + embeddings: Embeddings, documents: Sequence[_DocumentWithState] +) -> List[List[float]]: + if len(documents) and "embedded_doc" in documents[0].state: + embedded_documents = [doc.state["embedded_doc"] for doc in documents] + else: + embedded_documents = embeddings.embed_documents( + [d.page_content for d in documents] + ) + for doc, embedding in zip(documents, embedded_documents): + doc.state["embedded_doc"] = embedding + return embedded_documents + + +class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): + """Filter that drops redundant documents by comparing their embeddings.""" + + embeddings: Embeddings + """Embeddings to use for embedding document contents.""" + similarity_fn: Callable = cosine_similarity + """Similarity function for comparing documents. 
Function expected to take as input two matrices (List[List[float]]) and return a matrix of scores where higher values indicate greater similarity."""
    similarity_threshold: float = 0.95
    """Threshold for determining when two documents are similar enough
    to be considered redundant."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Filter down documents."""
        stateful_documents = get_stateful_documents(documents)
        embedded_documents = _get_embeddings_from_stateful_docs(
            self.embeddings, stateful_documents
        )
        included_idxs = _filter_similar_embeddings(
            embedded_documents, self.similarity_fn, self.similarity_threshold
        )
        return [stateful_documents[i] for i in sorted(included_idxs)]

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        raise NotImplementedError
diff --git a/langchain/math_utils.py b/langchain/math_utils.py
new file mode 100644
index 0000000000000..218af0475aea5
--- /dev/null
+++ b/langchain/math_utils.py
@@ -0,0 +1,22 @@
"""Math utils."""
from typing import List, Union

import numpy as np

Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]


def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X) == 0 or len(Y) == 0:
        return np.array([])
    X = np.array(X)
    Y = np.array(Y)
    if X.shape[1] != Y.shape[1]:
        raise ValueError("Number of columns in X and Y must be the same.")

    X_norm = np.linalg.norm(X, axis=1)
    Y_norm = np.linalg.norm(Y, axis=1)
    similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
    similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
    return similarity
diff --git a/langchain/output_parsers/boolean.py b/langchain/output_parsers/boolean.py
new file mode 100644
index 0000000000000..40890a9d81364
--- /dev/null
+++ b/langchain/output_parsers/boolean.py
@@ -0,0 +1,29 @@
from langchain.schema import BaseOutputParser


class BooleanOutputParser(BaseOutputParser[bool]):
    true_val: str = "YES"
    false_val: str = "NO"

    def parse(self, text: str) -> bool:
        """Parse the output of an LLM call to a boolean.

        Args:
            text: output of language model

        Returns:
            boolean

        """
        cleaned_text = text.strip()
        if cleaned_text not in (self.true_val, self.false_val):
            raise ValueError(
                f"BooleanOutputParser expected output value to either be "
                f"{self.true_val} or {self.false_val}. Received {cleaned_text}."
            )
        return cleaned_text == self.true_val

    @property
    def _type(self) -> str:
        """Snake-case string identifier for output parser type."""
        return "boolean_output_parser"
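The `LLMChainFilter` added later in this patch relies on a YES/NO style decision, and this parser turns such a reply into a boolean. A minimal sketch of its behavior, assuming the module path added above:

```python
from langchain.output_parsers.boolean import BooleanOutputParser

parser = BooleanOutputParser()
assert parser.parse("YES") is True
assert parser.parse("NO") is False
parser.parse("Maybe")  # raises ValueError: neither YES nor NO
```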
+            )
+        return cleaned_text == self.true_val
+
+    @property
+    def _type(self) -> str:
+        """Snake-case string identifier for output parser type."""
+        return "boolean_output_parser"
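(Not part of the diff: a minimal usage sketch of the parser added above; the `true_val`/`false_val` overrides are illustrative.)

```python
from langchain.output_parsers.boolean import BooleanOutputParser

parser = BooleanOutputParser()
assert parser.parse("YES") is True
assert parser.parse(" NO ") is False  # surrounding whitespace is stripped first
# The sentinel strings are configurable; any other cleaned value raises ValueError.
strict = BooleanOutputParser(true_val="TRUE", false_val="FALSE")
assert strict.parse("TRUE") is True
```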
diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py
index 869ea937c28e1..d89cf9d8fb152 100644
--- a/langchain/retrievers/__init__.py
+++ b/langchain/retrievers/__init__.py
@@ -1,4 +1,5 @@
 from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
+from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
 from langchain.retrievers.databerry import DataberryRetriever
 from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
 from langchain.retrievers.metal import MetalRetriever
@@ -13,6 +14,7 @@
 __all__ = [
     "ChatGPTPluginRetriever",
+    "ContextualCompressionRetriever",
    "RemoteLangChainRetriever",
    "PineconeHybridSearchRetriever",
    "MetalRetriever",
diff --git a/langchain/retrievers/contextual_compression.py b/langchain/retrievers/contextual_compression.py
new file mode 100644
index 0000000000000..788a391981e45
--- /dev/null
+++ b/langchain/retrievers/contextual_compression.py
@@ -0,0 +1,51 @@
+"""Retriever that wraps a base retriever and filters the results."""
+from typing import List
+
+from pydantic import BaseModel, Extra
+
+from langchain.retrievers.document_compressors.base import (
+    BaseDocumentCompressor,
+)
+from langchain.schema import BaseRetriever, Document
+
+
+class ContextualCompressionRetriever(BaseRetriever, BaseModel):
+    """Retriever that wraps a base retriever and compresses the results."""
+
+    base_compressor: BaseDocumentCompressor
+    """Compressor for compressing retrieved documents."""
+
+    base_retriever: BaseRetriever
+    """Base retriever to use for getting relevant documents."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    def get_relevant_documents(self, query: str) -> List[Document]:
+        """Get documents relevant for a query.
+
+        Args:
+            query: string to find relevant documents for
+
+        Returns:
+            List of relevant documents
+        """
+        docs = self.base_retriever.get_relevant_documents(query)
+        compressed_docs = self.base_compressor.compress_documents(docs, query)
+        return list(compressed_docs)
+
+    async def aget_relevant_documents(self, query: str) -> List[Document]:
+        """Get documents relevant for a query.
+
+        Args:
+            query: string to find relevant documents for
+
+        Returns:
+            List of relevant documents
+        """
+        docs = await self.base_retriever.aget_relevant_documents(query)
+        compressed_docs = await self.base_compressor.acompress_documents(docs, query)
+        return list(compressed_docs)
diff --git a/langchain/retrievers/document_compressors/__init__.py b/langchain/retrievers/document_compressors/__init__.py
new file mode 100644
index 0000000000000..528eae7194511
--- /dev/null
+++ b/langchain/retrievers/document_compressors/__init__.py
@@ -0,0 +1,17 @@
+from langchain.retrievers.document_compressors.base import DocumentCompressorPipeline
+from langchain.retrievers.document_compressors.chain_extract import (
+    LLMChainExtractor,
+)
+from langchain.retrievers.document_compressors.chain_filter import (
+    LLMChainFilter,
+)
+from langchain.retrievers.document_compressors.embeddings_filter import (
+    EmbeddingsFilter,
+)
+
+__all__ = [
+    "DocumentCompressorPipeline",
+    "EmbeddingsFilter",
+    "LLMChainExtractor",
+    "LLMChainFilter",
+]
diff --git a/langchain/retrievers/document_compressors/base.py b/langchain/retrievers/document_compressors/base.py
new file mode 100644
index 0000000000000..b42d95eadf013
--- /dev/null
+++ b/langchain/retrievers/document_compressors/base.py
@@ -0,0 +1,61 @@
+"""Interface for retrieved document compressors."""
+from abc import ABC, abstractmethod
+from typing import List, Sequence, Union
+
+from pydantic import BaseModel
+
+from langchain.schema import BaseDocumentTransformer, Document
+
+
+class BaseDocumentCompressor(BaseModel, ABC):
+    """Base abstraction for compressing retrieved documents."""
+
+    @abstractmethod
+    def compress_documents(
+        self, documents: Sequence[Document], query: str
+    ) -> Sequence[Document]:
+        """Compress retrieved documents given the query context."""
+
+    @abstractmethod
+    async def acompress_documents(
+        self, documents: Sequence[Document], query: str
+    ) -> Sequence[Document]:
+        """Compress retrieved documents given the query context."""
+
+
+class DocumentCompressorPipeline(BaseDocumentCompressor):
+    """Document compressor that uses a pipeline of transformers."""
+
+    transformers: List[Union[BaseDocumentTransformer, BaseDocumentCompressor]]
+    """List of document filters that are chained together and run in sequence."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def compress_documents(
+        self, documents: Sequence[Document], query: str
+    ) -> Sequence[Document]:
+        """Compress retrieved documents given the query context."""
+        for _transformer in self.transformers:
+            if isinstance(_transformer, BaseDocumentCompressor):
+                documents = _transformer.compress_documents(documents, query)
+            elif isinstance(_transformer, BaseDocumentTransformer):
+                documents = _transformer.transform_documents(documents)
+            else:
+                raise ValueError(f"Got unexpected transformer type: {_transformer}")
+        return documents
+
+    async def acompress_documents(
+        self, documents: Sequence[Document], query: str
+    ) -> Sequence[Document]:
+        """Compress retrieved documents given the query context."""
+        for _transformer in self.transformers:
+            if isinstance(_transformer, BaseDocumentCompressor):
+                documents = await _transformer.acompress_documents(documents, query)
+            elif isinstance(_transformer, BaseDocumentTransformer):
+                documents = await _transformer.atransform_documents(documents)
+            else:
+                raise ValueError(f"Got unexpected transformer type: {_transformer}")
+        return documents
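(Not part of the diff: a sketch of how the pipeline composes splitters, transformers, and compressors, mirroring the integration test later in this patch. Such a pipeline would typically be passed as the `base_compressor` of the `ContextualCompressionRetriever` above. The embeddings backend, chunk size, and threshold are illustrative and assume an OpenAI API key is configured.)

```python
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline,
    EmbeddingsFilter,
)
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter

embeddings = OpenAIEmbeddings()
pipeline = DocumentCompressorPipeline(
    transformers=[
        # 1. Split long documents into smaller chunks,
        CharacterTextSplitter(chunk_size=100, chunk_overlap=0, separator=". "),
        # 2. drop chunks that are near-duplicates of earlier chunks,
        EmbeddingsRedundantFilter(embeddings=embeddings),
        # 3. keep only chunks similar enough to the query embedding.
        EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.8),
    ]
)
docs = [Document(page_content="This sentence is about cows. foo bar baz.")]
compressed = pipeline.compress_documents(docs, "Tell me about farm animals")
```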
diff --git a/langchain/retrievers/document_compressors/chain_extract.py b/langchain/retrievers/document_compressors/chain_extract.py
new file mode 100644
index 0000000000000..6f638559443f1
--- /dev/null
+++ b/langchain/retrievers/document_compressors/chain_extract.py
@@ -0,0 +1,77 @@
+"""Document compressor that uses an LLM chain to extract the relevant parts of documents."""
+from typing import Any, Callable, Dict, Optional, Sequence
+
+from langchain import LLMChain, PromptTemplate
+from langchain.retrievers.document_compressors.base import (
+    BaseDocumentCompressor,
+)
+from langchain.retrievers.document_compressors.chain_extract_prompt import (
+    prompt_template,
+)
+from langchain.schema import BaseLanguageModel, BaseOutputParser, Document
+
+
+def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
+    """Return the compression chain input."""
+    return {"question": query, "context": doc.page_content}
+
+
+class NoOutputParser(BaseOutputParser[str]):
+    """Parse outputs, mapping the no-output sentinel string to an empty string."""
+
+    no_output_str: str = "NO_OUTPUT"
+
+    def parse(self, text: str) -> str:
+        cleaned_text = text.strip()
+        if cleaned_text == self.no_output_str:
+            return ""
+        return cleaned_text
+
+
+def _get_default_chain_prompt() -> PromptTemplate:
+    output_parser = NoOutputParser()
+    template = prompt_template.format(no_output_str=output_parser.no_output_str)
+    return PromptTemplate(
+        template=template,
+        input_variables=["question", "context"],
+        output_parser=output_parser,
+    )
+
+
+class LLMChainExtractor(BaseDocumentCompressor):
+    """Compressor that uses an LLM chain to extract the query-relevant parts of each document."""
+
+    llm_chain: LLMChain
+    """LLM wrapper to use for compressing documents."""
+
+    get_input: Callable[[str, Document], dict] = default_get_input
+    """Callable for constructing the chain input from the query and a Document."""
+
+    def compress_documents(
+        self, documents: Sequence[Document], query: str
+    ) -> Sequence[Document]:
+        """Compress page content of raw documents."""
+        compressed_docs = []
+        for doc in documents:
+            _input = self.get_input(query, doc)
+            output = self.llm_chain.predict_and_parse(**_input)
+            if len(output) == 0:
+                continue
+            compressed_docs.append(Document(page_content=output, metadata=doc.metadata))
+        return compressed_docs
+
+    async def acompress_documents(
+        self, documents: Sequence[Document], query: str
+    ) -> Sequence[Document]:
+        raise NotImplementedError
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLanguageModel,
+        prompt: Optional[PromptTemplate] = None,
+        get_input: Optional[Callable[[str, Document], dict]] = None,
+    ) -> "LLMChainExtractor":
+        """Initialize from LLM."""
+        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
+        _get_input = get_input if get_input is not None else default_get_input
+        llm_chain = LLMChain(llm=llm, prompt=_prompt)
+        return cls(llm_chain=llm_chain, get_input=_get_input)
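(Not part of the diff: a usage sketch for the extractor; the chat model is illustrative and assumes an OpenAI API key, as in the integration tests later in this patch.)

```python
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.schema import Document

compressor = LLMChainExtractor.from_llm(ChatOpenAI())
docs = [
    Document(
        page_content=(
            "The Roman Empire followed the Roman Republic. "
            "Don't you just love Caesar salad?"
        )
    )
]
# Each returned Document keeps only the passages the LLM judged relevant;
# a document whose output is the NO_OUTPUT sentinel is dropped entirely.
compressed = compressor.compress_documents(docs, "Tell me about the Roman Empire")
```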
diff --git a/langchain/retrievers/document_compressors/chain_extract_prompt.py b/langchain/retrievers/document_compressors/chain_extract_prompt.py
new file mode 100644
index 0000000000000..c27b8770cb4b0
--- /dev/null
+++ b/langchain/retrievers/document_compressors/chain_extract_prompt.py
@@ -0,0 +1,11 @@
+# flake8: noqa
+prompt_template = """Given the following question and context, extract any part of the context *AS IS* that is relevant to answer the question. If none of the context is relevant return {no_output_str}.
+
+Remember, *DO NOT* edit the extracted parts of the context.
+
+> Question: {{question}}
+> Context:
+>>>
+{{context}}
+>>>
+Extracted relevant parts:"""
diff --git a/langchain/retrievers/document_compressors/chain_filter.py b/langchain/retrievers/document_compressors/chain_filter.py
new file mode 100644
index 0000000000000..f5e33e6bf65ab
--- /dev/null
+++ b/langchain/retrievers/document_compressors/chain_filter.py
@@ -0,0 +1,65 @@
+"""Filter that uses an LLM to drop documents that aren't relevant to the query."""
+from typing import Any, Callable, Dict, Optional, Sequence
+
+from langchain import BasePromptTemplate, LLMChain, PromptTemplate
+from langchain.output_parsers.boolean import BooleanOutputParser
+from langchain.retrievers.document_compressors.base import (
+    BaseDocumentCompressor,
+)
+from langchain.retrievers.document_compressors.chain_filter_prompt import (
+    prompt_template,
+)
+from langchain.schema import BaseLanguageModel, Document
+
+
+def _get_default_chain_prompt() -> PromptTemplate:
+    return PromptTemplate(
+        template=prompt_template,
+        input_variables=["question", "context"],
+        output_parser=BooleanOutputParser(),
+    )
+
+
+def default_get_input(query: str, doc: Document) -> Dict[str, Any]:
+    """Return the compression chain input."""
+    return {"question": query, "context": doc.page_content}
+
+
+class LLMChainFilter(BaseDocumentCompressor):
+    """Filter that drops documents that aren't relevant to the query."""
+
+    llm_chain: LLMChain
+    """LLM wrapper to use for filtering documents.
+    The chain prompt is expected to have a BooleanOutputParser."""
+
+    get_input: Callable[[str, Document], dict] = default_get_input
+    """Callable for constructing the chain input from the query and a Document."""
+
+    def compress_documents(
+        self, documents: Sequence[Document], query: str
+    ) -> Sequence[Document]:
+        """Filter down documents based on their relevance to the query."""
+        filtered_docs = []
+        for doc in documents:
+            _input = self.get_input(query, doc)
+            include_doc = self.llm_chain.predict_and_parse(**_input)
+            if include_doc:
+                filtered_docs.append(doc)
+        return filtered_docs
+
+    async def acompress_documents(
+        self, documents: Sequence[Document], query: str
+    ) -> Sequence[Document]:
+        """Filter down documents."""
+        raise NotImplementedError
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLanguageModel,
+        prompt: Optional[BasePromptTemplate] = None,
+        **kwargs: Any
+    ) -> "LLMChainFilter":
+        _prompt = prompt if prompt is not None else _get_default_chain_prompt()
+        llm_chain = LLMChain(llm=llm, prompt=_prompt)
+        return cls(llm_chain=llm_chain, **kwargs)
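(Not part of the diff: a usage sketch for the filter; the chat model is illustrative. Unlike `LLMChainExtractor`, this compressor is all-or-nothing per document; page content is never rewritten.)

```python
from langchain.chat_models import ChatOpenAI
from langchain.retrievers.document_compressors import LLMChainFilter
from langchain.schema import Document

docs = [
    Document(page_content="What happened to all of my cookies?"),
    Document(page_content="My favorite color is green"),
]
doc_filter = LLMChainFilter.from_llm(llm=ChatOpenAI())
# A document is kept when the chain's BooleanOutputParser returns True for it.
relevant = doc_filter.compress_documents(docs, "Things I said related to food")
```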
diff --git a/langchain/retrievers/document_compressors/chain_filter_prompt.py b/langchain/retrievers/document_compressors/chain_filter_prompt.py
new file mode 100644
index 0000000000000..5376dfa2a1859
--- /dev/null
+++ b/langchain/retrievers/document_compressors/chain_filter_prompt.py
@@ -0,0 +1,9 @@
+# flake8: noqa
+prompt_template = """Given the following question and context, return YES if the context is relevant to the question and NO if it isn't.
+
+> Question: {question}
+> Context:
+>>>
+{context}
+>>>
+> Relevant (YES / NO):"""
diff --git a/langchain/retrievers/document_compressors/embeddings_filter.py b/langchain/retrievers/document_compressors/embeddings_filter.py
new file mode 100644
index 0000000000000..543380189d8da
--- /dev/null
+++ b/langchain/retrievers/document_compressors/embeddings_filter.py
@@ -0,0 +1,70 @@
+"""Document compressor that uses embeddings to drop documents unrelated to the query."""
+from typing import Callable, Dict, Optional, Sequence
+
+import numpy as np
+from pydantic import root_validator
+
+from langchain.document_transformers import (
+    _get_embeddings_from_stateful_docs,
+    get_stateful_documents,
+)
+from langchain.embeddings.base import Embeddings
+from langchain.math_utils import cosine_similarity
+from langchain.retrievers.document_compressors.base import (
+    BaseDocumentCompressor,
+)
+from langchain.schema import Document
+
+
+class EmbeddingsFilter(BaseDocumentCompressor):
+    """Filter that drops documents whose embeddings aren't similar enough to the query's."""
+
+    embeddings: Embeddings
+    """Embeddings to use for embedding document contents and queries."""
+    similarity_fn: Callable = cosine_similarity
+    """Similarity function for comparing documents. Function expected to take as input
+    two matrices (List[List[float]]) and return a matrix of scores where higher values
+    indicate greater similarity."""
+    k: Optional[int] = 20
+    """The number of relevant documents to return. Can be set to None, in which case
+    `similarity_threshold` must be specified. Defaults to 20."""
+    similarity_threshold: Optional[float]
+    """Threshold for determining when a document is similar enough to the query to be
+    included. Defaults to None; must be specified if `k` is set to None."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    @root_validator()
+    def validate_params(cls, values: Dict) -> Dict:
+        """Validate similarity parameters."""
+        if values["k"] is None and values["similarity_threshold"] is None:
+            raise ValueError("Must specify one of `k` or `similarity_threshold`.")
+        return values
+
+    def compress_documents(
+        self, documents: Sequence[Document], query: str
+    ) -> Sequence[Document]:
+        """Filter documents based on similarity of their embeddings to the query."""
+        stateful_documents = get_stateful_documents(documents)
+        embedded_documents = _get_embeddings_from_stateful_docs(
+            self.embeddings, stateful_documents
+        )
+        embedded_query = self.embeddings.embed_query(query)
+        similarity = self.similarity_fn([embedded_query], embedded_documents)[0]
+        included_idxs = np.arange(len(embedded_documents))
+        if self.k is not None:
+            included_idxs = np.argsort(similarity)[::-1][: self.k]
+        if self.similarity_threshold is not None:
+            similar_enough = np.where(
+                similarity[included_idxs] > self.similarity_threshold
+            )
+            included_idxs = included_idxs[similar_enough]
+        return [stateful_documents[i] for i in included_idxs]
+
+    async def acompress_documents(
+        self, documents: Sequence[Document], query: str
+    ) -> Sequence[Document]:
+        """Filter down documents."""
+        raise NotImplementedError
diff --git a/langchain/schema.py b/langchain/schema.py
index 65f530948c356..821dc70ad6037 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -2,7 +2,17 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generic, List, NamedTuple, Optional, TypeVar, Union
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    TypeVar,
+    Union,
+)
 
 from pydantic import
BaseModel, Extra, Field, root_validator @@ -394,16 +404,17 @@ class OutputParserException(Exception): pass -D = TypeVar("D", bound=Document) - - -class BaseDocumentTransformer(ABC, Generic[D]): +class BaseDocumentTransformer(ABC): """Base interface for transforming documents.""" @abstractmethod - def transform_documents(self, documents: List[D], **kwargs: Any) -> List[D]: + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: """Transform a list of documents.""" @abstractmethod - async def atransform_documents(self, documents: List[D], **kwargs: Any) -> List[D]: + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: """Asynchronously transform a list of documents.""" diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index b9c3c24b599d0..7afd8b25d10bd 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -13,6 +13,7 @@ List, Literal, Optional, + Sequence, Union, ) @@ -22,7 +23,7 @@ logger = logging.getLogger(__name__) -class TextSplitter(BaseDocumentTransformer[Document], ABC): +class TextSplitter(BaseDocumentTransformer, ABC): """Interface for splitting text into chunks.""" def __init__( @@ -63,7 +64,7 @@ def split_documents(self, documents: List[Document]) -> List[Document]: """Split documents.""" texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] - return self.create_documents(texts, metadatas) + return self.create_documents(texts, metadatas=metadatas) def _join_docs(self, docs: List[str], separator: str) -> Optional[str]: text = separator.join(docs) @@ -173,15 +174,15 @@ def _tiktoken_encoder(text: str, **kwargs: Any) -> int: return cls(length_function=_tiktoken_encoder, **kwargs) def transform_documents( - self, documents: List[Document], **kwargs: Any - ) -> List[Document]: - """Transform list of documents by splitting them.""" - return self.split_documents(documents) + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Transform sequence of documents by splitting them.""" + return self.split_documents(list(documents)) async def atransform_documents( - self, documents: List[Document], **kwargs: Any - ) -> List[Document]: - """Asynchronously transform a list of documents by splitting them.""" + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Asynchronously transform a sequence of documents by splitting them.""" raise NotImplementedError diff --git a/langchain/vectorstores/utils.py b/langchain/vectorstores/utils.py index e34a7703a7ead..50e8ae6cae838 100644 --- a/langchain/vectorstores/utils.py +++ b/langchain/vectorstores/utils.py @@ -4,10 +4,7 @@ import numpy as np - -def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: - """Calculate cosine similarity with numpy.""" - return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) +from langchain.math_utils import cosine_similarity def maximal_marginal_relevance( @@ -17,22 +14,26 @@ def maximal_marginal_relevance( k: int = 4, ) -> List[int]: """Calculate maximal marginal relevance.""" - idxs: List[int] = [] - while len(idxs) < k: + if min(k, len(embedding_list)) <= 0: + return [] + similarity_to_query = cosine_similarity([query_embedding], embedding_list)[0] + most_similar = int(np.argmax(similarity_to_query)) + idxs = [most_similar] + selected = np.array([embedding_list[most_similar]]) + while len(idxs) < min(k, len(embedding_list)): best_score = -np.inf idx_to_add = 
-1 - for i, emb in enumerate(embedding_list): + similarity_to_selected = cosine_similarity(embedding_list, selected) + for i, query_score in enumerate(similarity_to_query): if i in idxs: continue - first_part = cosine_similarity(query_embedding, emb) - second_part = 0.0 - for j in idxs: - cos_sim = cosine_similarity(emb, embedding_list[j]) - if cos_sim > second_part: - second_part = cos_sim - equation_score = lambda_mult * first_part - (1 - lambda_mult) * second_part + redundant_score = max(similarity_to_selected[i]) + equation_score = ( + lambda_mult * query_score - (1 - lambda_mult) * redundant_score + ) if equation_score > best_score: best_score = equation_score idx_to_add = i idxs.append(idx_to_add) + selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) return idxs diff --git a/tests/integration_tests/retrievers/__init__.py b/tests/integration_tests/retrievers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration_tests/retrievers/document_compressors/__init__.py b/tests/integration_tests/retrievers/document_compressors/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration_tests/retrievers/document_compressors/test_base.py b/tests/integration_tests/retrievers/document_compressors/test_base.py new file mode 100644 index 0000000000000..389d4d04e7e39 --- /dev/null +++ b/tests/integration_tests/retrievers/document_compressors/test_base.py @@ -0,0 +1,28 @@ +"""Integration test for compression pipelines.""" +from langchain.document_transformers import EmbeddingsRedundantFilter +from langchain.embeddings import OpenAIEmbeddings +from langchain.retrievers.document_compressors import ( + DocumentCompressorPipeline, + EmbeddingsFilter, +) +from langchain.schema import Document +from langchain.text_splitter import CharacterTextSplitter + + +def test_document_compressor_pipeline() -> None: + embeddings = OpenAIEmbeddings() + splitter = CharacterTextSplitter(chunk_size=20, chunk_overlap=0, separator=". ") + redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings) + relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.8) + pipeline_filter = DocumentCompressorPipeline( + transformers=[splitter, redundant_filter, relevant_filter] + ) + texts = [ + "This sentence is about cows", + "This sentence was about cows", + "foo bar baz", + ] + docs = [Document(page_content=". 
".join(texts))] + actual = pipeline_filter.compress_documents(docs, "Tell me about farm animals") + assert len(actual) == 1 + assert actual[0].page_content in texts[:2] diff --git a/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py b/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py new file mode 100644 index 0000000000000..0fcfebf9c60ab --- /dev/null +++ b/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py @@ -0,0 +1,36 @@ +"""Integration test for LLMChainExtractor.""" +from langchain.chat_models import ChatOpenAI +from langchain.retrievers.document_compressors import LLMChainExtractor +from langchain.schema import Document + + +def test_llm_chain_extractor() -> None: + texts = [ + "The Roman Empire followed the Roman Republic.", + "I love chocolate chip cookies—my mother makes great cookies.", + "The first Roman emperor was Caesar Augustus.", + "Don't you just love Caesar salad?", + "The Roman Empire collapsed in 476 AD after the fall of Rome.", + "Let's go to Olive Garden!", + ] + doc = Document(page_content=" ".join(texts)) + compressor = LLMChainExtractor.from_llm(ChatOpenAI()) + actual = compressor.compress_documents([doc], "Tell me about the Roman Empire")[ + 0 + ].page_content + expected_returned = [0, 2, 4] + expected_not_returned = [1, 3, 5] + assert all([texts[i] in actual for i in expected_returned]) + assert all([texts[i] not in actual for i in expected_not_returned]) + + +def test_llm_chain_extractor_empty() -> None: + texts = [ + "I love chocolate chip cookies—my mother makes great cookies.", + "Don't you just love Caesar salad?", + "Let's go to Olive Garden!", + ] + doc = Document(page_content=" ".join(texts)) + compressor = LLMChainExtractor.from_llm(ChatOpenAI()) + actual = compressor.compress_documents([doc], "Tell me about the Roman Empire") + assert len(actual) == 0 diff --git a/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py b/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py new file mode 100644 index 0000000000000..1068a1e65a2b4 --- /dev/null +++ b/tests/integration_tests/retrievers/document_compressors/test_chain_filter.py @@ -0,0 +1,17 @@ +"""Integration test for llm-based relevant doc filtering.""" +from langchain.chat_models import ChatOpenAI +from langchain.retrievers.document_compressors import LLMChainFilter +from langchain.schema import Document + + +def test_llm_chain_filter() -> None: + texts = [ + "What happened to all of my cookies?", + "I wish there were better Italian restaurants in my neighborhood.", + "My favorite color is green", + ] + docs = [Document(page_content=t) for t in texts] + relevant_filter = LLMChainFilter.from_llm(llm=ChatOpenAI()) + actual = relevant_filter.compress_documents(docs, "Things I said related to food") + assert len(actual) == 2 + assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2 diff --git a/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py b/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py new file mode 100644 index 0000000000000..15a13e39654b2 --- /dev/null +++ b/tests/integration_tests/retrievers/document_compressors/test_embeddings_filter.py @@ -0,0 +1,39 @@ +"""Integration test for embedding-based relevant doc filtering.""" +import numpy as np + +from langchain.document_transformers import _DocumentWithState +from langchain.embeddings import OpenAIEmbeddings +from 
langchain.retrievers.document_compressors import EmbeddingsFilter +from langchain.schema import Document + + +def test_embeddings_filter() -> None: + texts = [ + "What happened to all of my cookies?", + "I wish there were better Italian restaurants in my neighborhood.", + "My favorite color is green", + ] + docs = [Document(page_content=t) for t in texts] + embeddings = OpenAIEmbeddings() + relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75) + actual = relevant_filter.compress_documents(docs, "What did I say about food?") + assert len(actual) == 2 + assert len(set(texts[:2]).intersection([d.page_content for d in actual])) == 2 + + +def test_embeddings_filter_with_state() -> None: + texts = [ + "What happened to all of my cookies?", + "I wish there were better Italian restaurants in my neighborhood.", + "My favorite color is green", + ] + query = "What did I say about food?" + embeddings = OpenAIEmbeddings() + embedded_query = embeddings.embed_query(query) + state = {"embedded_doc": np.zeros(len(embedded_query))} + docs = [_DocumentWithState(page_content=t, state=state) for t in texts] + docs[-1].state = {"embedded_doc": embedded_query} + relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75) + actual = relevant_filter.compress_documents(docs, query) + assert len(actual) == 1 + assert texts[-1] == actual[0].page_content diff --git a/tests/integration_tests/retrievers/test_contextual_compression.py b/tests/integration_tests/retrievers/test_contextual_compression.py new file mode 100644 index 0000000000000..60eb206b8b9a2 --- /dev/null +++ b/tests/integration_tests/retrievers/test_contextual_compression.py @@ -0,0 +1,25 @@ +from langchain.embeddings import OpenAIEmbeddings +from langchain.retrievers.contextual_compression import ContextualCompressionRetriever +from langchain.retrievers.document_compressors import EmbeddingsFilter +from langchain.vectorstores import Chroma + + +def test_contextual_compression_retriever_get_relevant_docs() -> None: + """Test get_relevant_docs.""" + texts = [ + "This is a document about the Boston Celtics", + "The Boston Celtics won the game by 20 points", + "I simply love going to the movies", + ] + embeddings = OpenAIEmbeddings() + base_compressor = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.75) + base_retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever( + search_kwargs={"k": len(texts)} + ) + retriever = ContextualCompressionRetriever( + base_compressor=base_compressor, base_retriever=base_retriever + ) + + actual = retriever.get_relevant_documents("Tell me about the Celtics") + assert len(actual) == 2 + assert texts[-1] not in [d.page_content for d in actual] diff --git a/tests/integration_tests/test_document_transformers.py b/tests/integration_tests/test_document_transformers.py new file mode 100644 index 0000000000000..d5a23dba38e12 --- /dev/null +++ b/tests/integration_tests/test_document_transformers.py @@ -0,0 +1,31 @@ +"""Integration test for embedding-based redundant doc filtering.""" +from langchain.document_transformers import ( + EmbeddingsRedundantFilter, + _DocumentWithState, +) +from langchain.embeddings import OpenAIEmbeddings +from langchain.schema import Document + + +def test_embeddings_redundant_filter() -> None: + texts = [ + "What happened to all of my cookies?", + "Where did all of my cookies go?", + "I wish there were better Italian restaurants in my neighborhood.", + ] + docs = [Document(page_content=t) for t in texts] + embeddings = 
OpenAIEmbeddings() + redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings) + actual = redundant_filter.transform_documents(docs) + assert len(actual) == 2 + assert set(texts[:2]).intersection([d.page_content for d in actual]) + + +def test_embeddings_redundant_filter_with_state() -> None: + texts = ["What happened to all of my cookies?", "foo bar baz"] + state = {"embedded_doc": [0.5] * 10} + docs = [_DocumentWithState(page_content=t, state=state) for t in texts] + embeddings = OpenAIEmbeddings() + redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings) + actual = redundant_filter.transform_documents(docs) + assert len(actual) == 1 diff --git a/tests/unit_tests/test_document_transformers.py b/tests/unit_tests/test_document_transformers.py new file mode 100644 index 0000000000000..26354aeff8721 --- /dev/null +++ b/tests/unit_tests/test_document_transformers.py @@ -0,0 +1,15 @@ +"""Unit tests for document transformers.""" +from langchain.document_transformers import _filter_similar_embeddings +from langchain.math_utils import cosine_similarity + + +def test__filter_similar_embeddings() -> None: + threshold = 0.79 + embedded_docs = [[1.0, 2.0], [1.0, 2.0], [2.0, 1.0], [2.0, 0.5], [0.0, 0.0]] + expected = [1, 3, 4] + actual = _filter_similar_embeddings(embedded_docs, cosine_similarity, threshold) + assert expected == actual + + +def test__filter_similar_embeddings_empty() -> None: + assert len(_filter_similar_embeddings([], cosine_similarity, 0.0)) == 0 diff --git a/tests/unit_tests/test_math_utils.py b/tests/unit_tests/test_math_utils.py new file mode 100644 index 0000000000000..34b390a578a99 --- /dev/null +++ b/tests/unit_tests/test_math_utils.py @@ -0,0 +1,39 @@ +"""Test math utility functions.""" +from typing import List + +import numpy as np + +from langchain.math_utils import cosine_similarity + + +def test_cosine_similarity_zero() -> None: + X = np.zeros((3, 3)) + Y = np.random.random((3, 3)) + expected = np.zeros((3, 3)) + actual = cosine_similarity(X, Y) + assert np.allclose(expected, actual) + + +def test_cosine_similarity_identity() -> None: + X = np.random.random((4, 4)) + expected = np.ones(4) + actual = np.diag(cosine_similarity(X, X)) + assert np.allclose(expected, actual) + + +def test_cosine_similarity_empty() -> None: + empty_list: List[List[float]] = [] + assert len(cosine_similarity(empty_list, empty_list)) == 0 + assert len(cosine_similarity(empty_list, np.random.random((3, 3)))) == 0 + + +def test_cosine_similarity() -> None: + X = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0], [1.0, 2.0, 0.0]] + Y = [[0.5, 1.0, 1.5], [1.0, 0.0, 0.0], [2.0, 5.0, 2.0]] + expected = [ + [1.0, 0.26726124, 0.83743579], + [0.53452248, 0.0, 0.87038828], + [0.5976143, 0.4472136, 0.93419873], + ] + actual = cosine_similarity(X, Y) + assert np.allclose(expected, actual) From d7942a9f1922cea59e619216ac2511734eb26f95 Mon Sep 17 00:00:00 2001 From: Zach Jones Date: Thu, 20 Apr 2023 21:50:59 -0400 Subject: [PATCH 16/16] Fix type annotation for `QueryCheckerTool.llm` (#3237) Currently `langchain.tools.sql_database.tool.QueryCheckerTool` has a field `llm` with type `BaseLLM`. This breaks initialization for some LLMs. 
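Chat models such as `ChatOpenAI` subclass `BaseChatModel`, which implements `BaseLanguageModel` but not `BaseLLM`, so pydantic rejects them during field validation.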
For example, trying to use it with GPT4: ```python from langchain.sql_database import SQLDatabase from langchain.chat_models import ChatOpenAI from langchain.tools.sql_database.tool import QueryCheckerTool db = SQLDatabase.from_uri("some_db_uri") llm = ChatOpenAI(model_name="gpt-4") tool = QueryCheckerTool(db=db, llm=llm) # pydantic.error_wrappers.ValidationError: 1 validation error for QueryCheckerTool # llm # Can't instantiate abstract class BaseLLM with abstract methods _agenerate, _generate, _llm_type (type=type_error) ``` Seems like much of the rest of the codebase has switched from `BaseLLM` to `BaseLanguageModel`. This PR makes the change for QueryCheckerTool as well Co-authored-by: Zachary Jones --- langchain/tools/sql_database/tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/langchain/tools/sql_database/tool.py b/langchain/tools/sql_database/tool.py index 3921b43a2fb6f..d9d6cf63e2112 100644 --- a/langchain/tools/sql_database/tool.py +++ b/langchain/tools/sql_database/tool.py @@ -6,7 +6,7 @@ from langchain.chains.llm import LLMChain from langchain.prompts import PromptTemplate from langchain.sql_database import SQLDatabase -from langchain.llms.base import BaseLLM +from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool from langchain.tools.sql_database.prompt import QUERY_CHECKER @@ -81,7 +81,7 @@ class QueryCheckerTool(BaseSQLDatabaseTool, BaseTool): Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/""" template: str = QUERY_CHECKER - llm: BaseLLM + llm: BaseLanguageModel llm_chain: LLMChain = Field(init=False) name = "query_checker_sql_db" description = """