Skip to content

Improve vectorizer kwargs and typing #291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,17 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
if not isinstance(prompt, str):
raise TypeError("Prompt must be a string.")

return self._vectorizer.embed(prompt)
result = self._vectorizer.embed(prompt)
return result # type: ignore

async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
"""Converts a text prompt to its vector representation using the
configured vectorizer."""
if not isinstance(prompt, str):
raise TypeError("Prompt must be a string.")

return await self._vectorizer.aembed(prompt)
result = await self._vectorizer.aembed(prompt)
return result # type: ignore

def _check_vector_dims(self, vector: List[float]):
"""Checks the size of the provided vector and raises an error if it
Expand Down
8 changes: 4 additions & 4 deletions redisvl/extensions/router/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,14 @@ def __call__(
if not vector:
if not statement:
raise ValueError("Must provide a vector or statement to the router")
vector = self.vectorizer.embed(statement)
vector = self.vectorizer.embed(statement) # type: ignore

aggregation_method = (
aggregation_method or self.routing_config.aggregation_method
)

# perform route classification
top_route_match = self._classify_route(vector, aggregation_method)
top_route_match = self._classify_route(vector, aggregation_method) # type: ignore
return top_route_match

@deprecated_argument("distance_threshold")
Expand All @@ -400,7 +400,7 @@ def route_many(
if not vector:
if not statement:
raise ValueError("Must provide a vector or statement to the router")
vector = self.vectorizer.embed(statement)
vector = self.vectorizer.embed(statement) # type: ignore

max_k = max_k or self.routing_config.max_k
aggregation_method = (
Expand All @@ -409,7 +409,7 @@ def route_many(

# classify routes
top_route_matches = self._classify_multi_route(
vector, max_k, aggregation_method
vector, max_k, aggregation_method # type: ignore
)

return top_route_matches
Expand Down
2 changes: 1 addition & 1 deletion redisvl/extensions/session_manager/semantic_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def add_messages(
role=message[ROLE_FIELD_NAME],
content=message[CONTENT_FIELD_NAME],
session_tag=session_tag,
vector_field=content_vector,
vector_field=content_vector, # type: ignore
)

if TOOL_FIELD_NAME in message:
Expand Down
68 changes: 57 additions & 11 deletions redisvl/utils/vectorize/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union

from pydantic import BaseModel, Field, field_validator

Expand Down Expand Up @@ -49,34 +49,69 @@ def check_dims(cls, value):
return value

@abstractmethod
def embed_many(
def embed(
self,
texts: List[str],
text: str,
preprocess: Optional[Callable] = None,
batch_size: int = 1000,
as_buffer: bool = False,
**kwargs,
) -> List[List[float]]:
) -> Union[List[float], bytes]:
"""Embed a chunk of text.

Args:
text: Text to embed
preprocess: Optional function to preprocess text
as_buffer: If True, returns a bytes object instead of a list

Returns:
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
object if as_buffer=True
"""
raise NotImplementedError

@abstractmethod
def embed(
def embed_many(
self,
text: str,
texts: List[str],
preprocess: Optional[Callable] = None,
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
) -> List[float]:
) -> Union[List[List[float]], List[bytes]]:
"""Embed multiple chunks of text.

Args:
texts: List of texts to embed
preprocess: Optional function to preprocess text
batch_size: Number of texts to process in each batch
as_buffer: If True, returns each embedding as a bytes object

Returns:
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
or as bytes objects if as_buffer=True
"""
raise NotImplementedError

async def aembed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
batch_size: int = 1000,
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
) -> List[List[float]]:
) -> Union[List[List[float]], List[bytes]]:
"""Asynchronously embed multiple chunks of text.

Args:
texts: List of texts to embed
preprocess: Optional function to preprocess text
batch_size: Number of texts to process in each batch
as_buffer: If True, returns each embedding as a bytes object

Returns:
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
or as bytes objects if as_buffer=True
"""
# Fallback to standard embedding call if no async support
return self.embed_many(texts, preprocess, batch_size, as_buffer, **kwargs)

Expand All @@ -86,7 +121,18 @@ async def aembed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
) -> List[float]:
) -> Union[List[float], bytes]:
"""Asynchronously embed a chunk of text.

Args:
text: Text to embed
preprocess: Optional function to preprocess text
as_buffer: If True, returns a bytes object instead of a list

Returns:
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
object if as_buffer=True
"""
# Fallback to standard embedding call if no async support
return self.embed(text, preprocess, as_buffer, **kwargs)

Expand Down
38 changes: 24 additions & 14 deletions redisvl/utils/vectorize/text/azureopenai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union

from pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
Expand Down Expand Up @@ -178,7 +178,7 @@ def embed_many(
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
) -> List[List[float]]:
) -> Union[List[List[float]], List[bytes]]:
"""Embed many chunks of texts using the AzureOpenAI API.

Args:
Expand All @@ -191,7 +191,8 @@ def embed_many(
to a byte string. Defaults to False.

Returns:
List[List[float]]: List of embeddings.
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
or as bytes objects if as_buffer=True

Raises:
TypeError: If the wrong input type is passed in for the text.
Expand All @@ -205,7 +206,9 @@ def embed_many(

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = self._client.embeddings.create(input=batch, model=self.model)
response = self._client.embeddings.create(
input=batch, model=self.model, **kwargs
)
embeddings += [
self._process_embedding(r.embedding, as_buffer, dtype)
for r in response.data
Expand All @@ -224,7 +227,7 @@ def embed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
) -> List[float]:
) -> Union[List[float], bytes]:
"""Embed a chunk of text using the AzureOpenAI API.

Args:
Expand All @@ -235,7 +238,8 @@ def embed(
to a byte string. Defaults to False.

Returns:
List[float]: Embedding.
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
object if as_buffer=True

Raises:
TypeError: If the wrong input type is passed in for the text.
Expand All @@ -248,7 +252,9 @@ def embed(

dtype = kwargs.pop("dtype", self.dtype)

result = self._client.embeddings.create(input=[text], model=self.model)
result = self._client.embeddings.create(
input=[text], model=self.model, **kwargs
)
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

@retry(
Expand All @@ -261,10 +267,10 @@ async def aembed_many(
self,
texts: List[str],
preprocess: Optional[Callable] = None,
batch_size: int = 1000,
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
) -> List[List[float]]:
) -> Union[List[List[float]], List[bytes]]:
"""Asynchronously embed many chunks of texts using the AzureOpenAI API.

Args:
Expand All @@ -277,7 +283,8 @@ async def aembed_many(
to a byte string. Defaults to False.

Returns:
List[List[float]]: List of embeddings.
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
or as bytes objects if as_buffer=True

Raises:
TypeError: If the wrong input type is passed in for the text.
Expand All @@ -292,7 +299,7 @@ async def aembed_many(
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = await self._aclient.embeddings.create(
input=batch, model=self.model
input=batch, model=self.model, **kwargs
)
embeddings += [
self._process_embedding(r.embedding, as_buffer, dtype)
Expand All @@ -312,7 +319,7 @@ async def aembed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
) -> List[float]:
) -> Union[List[float], bytes]:
"""Asynchronously embed a chunk of text using the AzureOpenAI API.

Args:
Expand All @@ -323,7 +330,8 @@ async def aembed(
to a byte string. Defaults to False.

Returns:
List[float]: Embedding.
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
object if as_buffer=True

Raises:
TypeError: If the wrong input type is passed in for the text.
Expand All @@ -336,7 +344,9 @@ async def aembed(

dtype = kwargs.pop("dtype", self.dtype)

result = await self._aclient.embeddings.create(input=[text], model=self.model)
result = await self._aclient.embeddings.create(
input=[text], model=self.model, **kwargs
)
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

@property
Expand Down
22 changes: 12 additions & 10 deletions redisvl/utils/vectorize/text/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union

from pydantic import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
Expand Down Expand Up @@ -135,16 +135,17 @@ def embed(
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
**kwargs,
) -> List[float]:
"""Embed a chunk of text using Amazon Bedrock.
) -> Union[List[float], bytes]:
"""Embed a chunk of text using the AWS Bedrock Embeddings API.

Args:
text (str): Text to embed.
preprocess (Optional[Callable]): Optional preprocessing function.
as_buffer (bool): Whether to return as byte buffer.

Returns:
List[float]: The embedding vector.
Union[List[float], bytes]: Embedding as a list of floats, or as a bytes
object if as_buffer=True

Raises:
TypeError: If text is not a string.
Expand All @@ -156,7 +157,7 @@ def embed(
text = preprocess(text)

response = self._client.invoke_model(
modelId=self.model, body=json.dumps({"inputText": text})
modelId=self.model, body=json.dumps({"inputText": text}), **kwargs
)
response_body = json.loads(response["body"].read())
embedding = response_body["embedding"]
Expand All @@ -177,17 +178,18 @@ def embed_many(
batch_size: int = 10,
as_buffer: bool = False,
**kwargs,
) -> List[List[float]]:
"""Embed multiple texts using Amazon Bedrock.
) -> Union[List[List[float]], List[bytes]]:
"""Embed many chunks of text using the AWS Bedrock Embeddings API.

Args:
texts (List[str]): List of texts to embed.
preprocess (Optional[Callable]): Optional preprocessing function.
batch_size (int): Size of batches for processing.
batch_size (int): Size of batches for processing. Defaults to 10.
as_buffer (bool): Whether to return as byte buffers.

Returns:
List[List[float]]: List of embedding vectors.
Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats,
or as bytes objects if as_buffer=True

Raises:
TypeError: If texts is not a list of strings.
Expand All @@ -206,7 +208,7 @@ def embed_many(
batch_embeddings = []
for text in batch:
response = self._client.invoke_model(
modelId=self.model, body=json.dumps({"inputText": text})
modelId=self.model, body=json.dumps({"inputText": text}), **kwargs
)
response_body = json.loads(response["body"].read())
batch_embeddings.append(response_body["embedding"])
Expand Down
Loading