diff --git a/README.md b/README.md index 3146fe3..37bdcdf 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ The following environment variables are required to run the application: - `PDF_EXTRACT_IMAGES`: (Optional) A boolean value indicating whether to extract images from PDF files. Default value is "False". - `DEBUG_RAG_API`: (Optional) Set to "True" to show more verbose logging output in the server console, and to enable postgresql database routes - `CONSOLE_JSON`: (Optional) Set to "True" to log as json for Cloud Logging aggregations -- `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "azure", "huggingface", "huggingfacetei" or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai" +- `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "bedrock", "azure", "huggingface", "huggingfacetei" or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai" - `EMBEDDINGS_MODEL`: (Optional) Set a valid embeddings model to use from the configured provider. - **Defaults** - openai: "text-embedding-3-small" @@ -69,6 +69,7 @@ The following environment variables are required to run the application: - huggingface: "sentence-transformers/all-MiniLM-L6-v2" - huggingfacetei: "http://huggingfacetei:3000". Hugging Face TEI uses model defined on TEI service launch. - ollama: "nomic-embed-text" + - bedrock: "amazon.titan-embed-text-v1" - `RAG_AZURE_OPENAI_API_VERSION`: (Optional) Default is `2023-05-15`. The version of the Azure OpenAI API. - `RAG_AZURE_OPENAI_API_KEY`: (Optional) The API key for Azure OpenAI service. - Note: `AZURE_OPENAI_API_KEY` will work but `RAG_AZURE_OPENAI_API_KEY` will override it in order to not conflict with LibreChat setting. @@ -79,6 +80,9 @@ The following environment variables are required to run the application: - `OLLAMA_BASE_URL`: (Optional) defaults to `http://ollama:11434`. 
- `ATLAS_SEARCH_INDEX`: (Optional) the name of the vector search index if using Atlas MongoDB, defaults to `vector_index` - `MONGO_VECTOR_COLLECTION`: Deprecated for MongoDB, please use `ATLAS_SEARCH_INDEX` and `COLLECTION_NAME` +- `AWS_DEFAULT_REGION`: (Optional) the AWS region used for bedrock embeddings; defaults to `us-east-1` +- `AWS_ACCESS_KEY_ID`: (Optional) the AWS access key ID used when `EMBEDDINGS_PROVIDER` is "bedrock" +- `AWS_SECRET_ACCESS_KEY`: (Optional) the AWS secret access key used when `EMBEDDINGS_PROVIDER` is "bedrock" Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables. diff --git a/config.py b/config.py index 36f9202..d8d65ba 100644 --- a/config.py +++ b/config.py @@ -2,11 +2,13 @@ import os import json import logging +import boto3 from enum import Enum from datetime import datetime from dotenv import find_dotenv, load_dotenv from langchain_ollama import OllamaEmbeddings from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpointEmbeddings +from langchain_aws import BedrockEmbeddings from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from starlette.middleware.base import BaseHTTPMiddleware from store_factory import get_vector_store @@ -25,6 +27,7 @@ class EmbeddingsProvider(Enum): HUGGINGFACE = "huggingface" HUGGINGFACETEI = "huggingfacetei" OLLAMA = "ollama" + BEDROCK = "bedrock" def get_env_variable( @@ -168,6 +171,8 @@ async def dispatch(self, request, call_next): ).rstrip("/") HF_TOKEN = get_env_variable("HF_TOKEN", "") OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434") +AWS_ACCESS_KEY_ID = get_env_variable("AWS_ACCESS_KEY_ID", "") +AWS_SECRET_ACCESS_KEY = get_env_variable("AWS_SECRET_ACCESS_KEY", "") ## Embeddings @@ -195,6 +200,17 @@ def init_embeddings(provider, model): return HuggingFaceEndpointEmbeddings(model=model) elif provider == EmbeddingsProvider.OLLAMA: return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL) + elif provider == EmbeddingsProvider.BEDROCK: + session = boto3.Session( 
aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + region_name=AWS_DEFAULT_REGION, + ) + return BedrockEmbeddings( + client=session.client("bedrock-runtime"), + model_id=model, + region_name=AWS_DEFAULT_REGION, + ) else: raise ValueError(f"Unsupported embeddings provider: {provider}") @@ -217,6 +233,13 @@ def init_embeddings(provider, model): ) elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA: EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text") +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.BEDROCK: + EMBEDDINGS_MODEL = get_env_variable( + "EMBEDDINGS_MODEL", "amazon.titan-embed-text-v1" + ) + AWS_DEFAULT_REGION = get_env_variable( + "AWS_DEFAULT_REGION", "us-east-1" + ) else: raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}") diff --git a/requirements.txt b/requirements.txt index f015b94..28f18a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,8 @@ langchain==0.3 langchain_community==0.3 langchain_openai==0.2.0 langchain_core==0.3.5 +langchain-aws==0.2.1 +boto3==1.34.144 sqlalchemy==2.0.28 python-dotenv==1.0.1 fastapi==0.110.0