Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Bulk LocalAIEmbeddings #22666

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 68 additions & 55 deletions libs/community/langchain_community/embeddings/localai.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def _create_retry_decorator(embeddings: LocalAIEmbeddings) -> Callable[[Any], An
stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai.APITimeoutError)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.InternalServerError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
Expand All @@ -64,11 +64,11 @@ def _async_retry_decorator(embeddings: LocalAIEmbeddings) -> Any:
stop=stop_after_attempt(embeddings.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai.APITimeoutError)
| retry_if_exception_type(openai.APIError)
| retry_if_exception_type(openai.APIConnectionError)
| retry_if_exception_type(openai.RateLimitError)
| retry_if_exception_type(openai.InternalServerError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
Expand All @@ -86,10 +86,10 @@ async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:

# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
def _check_response(response: dict) -> dict:
if any(len(d["embedding"]) == 1 for d in response["data"]):
if any(len(d.embedding) == 1 for d in response.data):
import openai

raise openai.error.APIError("LocalAI API returned an empty embedding")
raise openai.APIError("LocalAI API returned an empty embedding")
return response


Expand All @@ -110,7 +110,7 @@ async def async_embed_with_retry(embeddings: LocalAIEmbeddings, **kwargs: Any) -

@_async_retry_decorator(embeddings)
async def _async_embed_with_retry(**kwargs: Any) -> Any:
response = await embeddings.client.acreate(**kwargs)
response = await embeddings.async_client.create(**kwargs)
return _check_response(response)

return await _async_embed_with_retry(**kwargs)
Expand All @@ -135,15 +135,30 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
openai_api_base="http://localhost:8080"
)

Specifying proxy:
.. code-block:: python

from langchain_community.embeddings import LocalAIEmbeddings
import openai
import httpx
openai = LocalAIEmbeddings(
openai_api_key="random-string",
client=openai.OpenAI(
base_url="http://localhost:8080",
http_client=openai.DefaultHttpxClient(
proxies="http://localhost:8899",
transport=httpx.HTTPTransport(local_address="0.0.0.0"),
),
api_key="random-string").embeddings
)
"""

client: Any #: :meta private:
async_client: Any #: :meta private:
model: str = "text-embedding-ada-002"
deployment: str = model
openai_api_version: Optional[str] = None
openai_api_base: Optional[str] = None
# to support explicit proxy for LocalAI
openai_proxy: Optional[str] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why delete this?

embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once."""
openai_api_key: Optional[str] = None
Expand Down Expand Up @@ -205,12 +220,6 @@ def validate_environment(cls, values: Dict) -> Dict:
"OPENAI_API_BASE",
default="",
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
)

default_api_version = ""
values["openai_api_version"] = get_from_dict_or_env(
Expand All @@ -228,7 +237,17 @@ def validate_environment(cls, values: Dict) -> Dict:
try:
import openai

values["client"] = openai.Embedding
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
}
if not values.get("client"):
values["client"] = openai.OpenAI(**client_params).embeddings
if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(**client_params).embeddings
except ImportError:
raise ImportError(
"Could not import openai python package. "
Expand All @@ -240,50 +259,49 @@ def validate_environment(cls, values: Dict) -> Dict:
def _invocation_params(self) -> Dict:
openai_args = {
"model": self.model,
"request_timeout": self.request_timeout,
"headers": self.headers,
"api_key": self.openai_api_key,
"organization": self.openai_organization,
"api_base": self.openai_api_base,
"api_version": self.openai_api_version,
**self.model_kwargs,
}
if self.openai_proxy:
import openai

openai.proxy = {
"http": self.openai_proxy,
"https": self.openai_proxy,
} # type: ignore[assignment]
return openai_args

def _embedding_func(
    self, text: Union[str, List[str]], *, engine: str
) -> List[List[float]]:
    """Call out to LocalAI's embedding endpoint.

    Accepts either a single string or a batch of strings and always
    returns a list of embeddings, one per input text, so callers can
    embed a whole document batch in a single request.

    Args:
        text: A single text or a list of texts to embed.
        engine: Deployment/engine name (kept for interface compatibility).

    Returns:
        A list of embedding vectors, one per input text.
    """
    # handle large input text
    if self.model.endswith("001"):
        # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
        # replace newlines, which can negatively affect performance.
        if isinstance(text, str):
            text = text.replace("\n", " ")
        else:
            text = [t.replace("\n", " ") for t in text]
    # openai>=1.0 returns a pydantic object; read fields by attribute.
    embedding_data = embed_with_retry(
        self,
        input=[text] if isinstance(text, str) else text,
        **self._invocation_params,
    ).data
    return [d.embedding for d in embedding_data]

async def _aembedding_func(
    self, text: Union[str, List[str]], *, engine: str
) -> List[List[float]]:
    """Call out to LocalAI's embedding endpoint asynchronously.

    Async counterpart of ``_embedding_func``: accepts a single string or
    a batch of strings and returns one embedding per input text.

    Args:
        text: A single text or a list of texts to embed.
        engine: Deployment/engine name (kept for interface compatibility).

    Returns:
        A list of embedding vectors, one per input text.
    """
    # handle large input text
    if self.model.endswith("001"):
        # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
        # replace newlines, which can negatively affect performance.
        if isinstance(text, str):
            text = text.replace("\n", " ")
        else:
            text = [t.replace("\n", " ") for t in text]
    # openai>=1.0 returns a pydantic object; read fields by attribute.
    embedding_data = (
        await async_embed_with_retry(
            self,
            input=[text] if isinstance(text, str) else text,
            **self._invocation_params,
        )
    ).data
    return [d.embedding for d in embedding_data]

def embed_documents(
    self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
    """Call out to LocalAI's embedding endpoint for embedding search docs.

    Args:
        texts: The list of texts to embed.
        chunk_size: The chunk size of embeddings. If None, will use the
            chunk size specified by the class. (Currently unused: the
            whole batch is sent in a single request.)

    Returns:
        List of embeddings, one for each text.
    """
    # Send the full batch in one request instead of one call per text.
    return self._embedding_func(texts, engine=self.deployment)

async def aembed_documents(
    self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
    """Call out to LocalAI's embedding endpoint async for embedding search docs.

    Args:
        texts: The list of texts to embed.
        chunk_size: The chunk size of embeddings. If None, will use the
            chunk size specified by the class. (Currently unused: the
            whole batch is sent in a single request.)

    Returns:
        List of embeddings, one for each text.
    """
    # Send the full batch in one request instead of one call per text.
    return await self._aembedding_func(texts, engine=self.deployment)

def embed_query(self, text: str) -> List[float]:
    """Call out to LocalAI's embedding endpoint for embedding query text.

    Args:
        text: The text to embed.

    Returns:
        Embedding for the text.
    """
    # The batch helper returns a list of embeddings; unwrap the single one.
    return self._embedding_func([text], engine=self.deployment)[0]

async def aembed_query(self, text: str) -> List[float]:
    """Call out to LocalAI's embedding endpoint async for embedding query text.

    Args:
        text: The text to embed.

    Returns:
        Embedding for the text.
    """
    # The batch helper returns a list of embeddings; unwrap the single one.
    embeddings = await self._aembedding_func([text], engine=self.deployment)
    return embeddings[0]
Loading
Loading