Skip to content

Commit

Permalink
🪨 feat: AWS Bedrock embeddings support (#75)
Browse files Browse the repository at this point in the history
* "WIP: adding bedrock embeddings"

* WIP: feat bedrock embeddings support

* feat: aws bedrock embeddings support

* refactor: update aws region var name

* docs: update env variables documentation for bedrock

* docs: add bedrock embeddings provider in list
  • Loading branch information
ScarFX authored Sep 27, 2024
1 parent edd8a0c commit f21b6e7
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,15 @@ The following environment variables are required to run the application:
- `PDF_EXTRACT_IMAGES`: (Optional) A boolean value indicating whether to extract images from PDF files. Default value is "False".
- `DEBUG_RAG_API`: (Optional) Set to "True" to show more verbose logging output in the server console, and to enable postgresql database routes
- `CONSOLE_JSON`: (Optional) Set to "True" to log as json for Cloud Logging aggregations
- `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "azure", "huggingface", "huggingfacetei" or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai"
- `EMBEDDINGS_PROVIDER`: (Optional) either "openai", "bedrock", "azure", "huggingface", "huggingfacetei" or "ollama", where "huggingface" uses sentence_transformers; defaults to "openai"
- `EMBEDDINGS_MODEL`: (Optional) Set a valid embeddings model to use from the configured provider.
- **Defaults**
- openai: "text-embedding-3-small"
- azure: "text-embedding-3-small" (will be used as your Azure Deployment)
- huggingface: "sentence-transformers/all-MiniLM-L6-v2"
- huggingfacetei: "http://huggingfacetei:3000". Hugging Face TEI uses model defined on TEI service launch.
- ollama: "nomic-embed-text"
- bedrock: "amazon.titan-embed-text-v1"
- `RAG_AZURE_OPENAI_API_VERSION`: (Optional) Default is `2023-05-15`. The version of the Azure OpenAI API.
- `RAG_AZURE_OPENAI_API_KEY`: (Optional) The API key for Azure OpenAI service.
- Note: `AZURE_OPENAI_API_KEY` will work but `RAG_AZURE_OPENAI_API_KEY` will override it in order to not conflict with LibreChat setting.
Expand All @@ -79,6 +80,9 @@ The following environment variables are required to run the application:
- `OLLAMA_BASE_URL`: (Optional) defaults to `http://ollama:11434`.
- `ATLAS_SEARCH_INDEX`: (Optional) the name of the vector search index if using Atlas MongoDB, defaults to `vector_index`
- `MONGO_VECTOR_COLLECTION`: Deprecated for MongoDB, please use `ATLAS_SEARCH_INDEX` and `COLLECTION_NAME`
- `AWS_DEFAULT_REGION`: (Optional) defaults to `us-east-1`
- `AWS_ACCESS_KEY_ID`: (Optional) needed for bedrock embeddings
- `AWS_SECRET_ACCESS_KEY`: (Optional) needed for bedrock embeddings

Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables.

Expand Down
23 changes: 23 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import os
import json
import logging
import boto3
from enum import Enum
from datetime import datetime
from dotenv import find_dotenv, load_dotenv
from langchain_ollama import OllamaEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpointEmbeddings
from langchain_aws import BedrockEmbeddings
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from starlette.middleware.base import BaseHTTPMiddleware
from store_factory import get_vector_store
Expand All @@ -25,6 +27,7 @@ class EmbeddingsProvider(Enum):
HUGGINGFACE = "huggingface"
HUGGINGFACETEI = "huggingfacetei"
OLLAMA = "ollama"
BEDROCK = "bedrock"


def get_env_variable(
Expand Down Expand Up @@ -168,6 +171,8 @@ async def dispatch(self, request, call_next):
).rstrip("/")
HF_TOKEN = get_env_variable("HF_TOKEN", "")
OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")
AWS_ACCESS_KEY_ID = get_env_variable("AWS_ACCESS_KEY_ID", "")
AWS_SECRET_ACCESS_KEY = get_env_variable("AWS_SECRET_ACCESS_KEY", "")

## Embeddings

Expand Down Expand Up @@ -195,6 +200,17 @@ def init_embeddings(provider, model):
return HuggingFaceEndpointEmbeddings(model=model)
elif provider == EmbeddingsProvider.OLLAMA:
return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL)
elif provider == EmbeddingsProvider.BEDROCK:
session = boto3.Session(
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
region_name=AWS_DEFAULT_REGION,
)
return BedrockEmbeddings(
client=session.client("bedrock-runtime"),
model_id=model,
region_name=AWS_DEFAULT_REGION,
)
else:
raise ValueError(f"Unsupported embeddings provider: {provider}")

Expand All @@ -217,6 +233,13 @@ def init_embeddings(provider, model):
)
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA:
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text")
elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.BEDROCK:
EMBEDDINGS_MODEL = get_env_variable(
"EMBEDDINGS_MODEL", "amazon.titan-embed-text-v1"
)
AWS_DEFAULT_REGION = get_env_variable(
"AWS_DEFAULT_REGION", "us-east-1"
)
else:
raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}")

Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ langchain==0.3
langchain_community==0.3
langchain_openai==0.2.0
langchain_core==0.3.5
langchain-aws==0.2.1
boto3==1.34.144
sqlalchemy==2.0.28
python-dotenv==1.0.1
fastapi==0.110.0
Expand Down

0 comments on commit f21b6e7

Please sign in to comment.