diff --git a/docs/api/schema.rst b/docs/api/schema.rst index ebe4ca8a..36245ba4 100644 --- a/docs/api/schema.rst +++ b/docs/api/schema.rst @@ -88,7 +88,7 @@ Each field type supports specific attributes that customize its behavior. Below - `dims`: Dimensionality of the vector. - `algorithm`: Indexing algorithm (`flat` or `hnsw`). -- `datatype`: Float datatype of the vector (`float32` or `float64`). +- `datatype`: Float datatype of the vector (`bfloat16`, `float16`, `float32`, `float64`). - `distance_metric`: Metric for measuring query relevance (`COSINE`, `L2`, `IP`). **HNSW Vector Field Specific Attributes**: diff --git a/docs/examples/openai_qna.ipynb b/docs/examples/openai_qna.ipynb index 614ed3e3..a1e59490 100644 --- a/docs/examples/openai_qna.ipynb +++ b/docs/examples/openai_qna.ipynb @@ -579,7 +579,7 @@ "api_key = os.getenv(\"OPENAI_API_KEY\") or getpass.getpass(\"Enter your OpenAI API key: \")\n", "oaip = OpenAITextVectorizer(EMBEDDINGS_MODEL, api_config={\"api_key\": api_key})\n", "\n", - "chunked_data[\"embedding\"] = oaip.embed_many(chunked_data[\"content\"].tolist(), as_buffer=True)\n", + "chunked_data[\"embedding\"] = oaip.embed_many(chunked_data[\"content\"].tolist(), as_buffer=True, dtype=\"float32\")\n", "chunked_data" ] }, @@ -1073,7 +1073,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.12.2" }, "orig_nbformat": 4 }, diff --git a/docs/user_guide/hash_vs_json_05.ipynb b/docs/user_guide/hash_vs_json_05.ipynb index a046a963..9cb0092b 100644 --- a/docs/user_guide/hash_vs_json_05.ipynb +++ b/docs/user_guide/hash_vs_json_05.ipynb @@ -429,7 +429,7 @@ "json_data = data.copy()\n", "\n", "for d in json_data:\n", - " d['user_embedding'] = buffer_to_array(d['user_embedding'], dtype=np.float32)" + " d['user_embedding'] = buffer_to_array(d['user_embedding'], dtype='float32')" ] }, { diff --git a/docs/user_guide/vectorizers_04.ipynb b/docs/user_guide/vectorizers_04.ipynb index f9bc9b82..90b05892 100644 --- a/docs/user_guide/vectorizers_04.ipynb +++ b/docs/user_guide/vectorizers_04.ipynb @@ -356,7 +356,7 @@ "outputs": [], "source": [ "# You can also create many embeddings at once\n", - "embeddings = hf.embed_many(sentences, as_buffer=True)\n" + "embeddings = hf.embed_many(sentences, as_buffer=True, dtype=\"float32\")\n" ] }, { @@ -569,7 +569,7 @@ "source": [ "from redisvl.utils.vectorize import CustomTextVectorizer\n", "\n", - "def generate_embeddings(text_input):\n", + "def generate_embeddings(text_input, **kwargs):\n", " return [0.101] * 768\n", "\n", "custom_vectorizer = CustomTextVectorizer(generate_embeddings)\n", diff --git a/poetry.lock b/poetry.lock index 24c3048f..bfeb29d5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1003,12 +1003,12 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, ] proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -2351,6 +2351,43 @@ files = [ intel-openmp = "==2021.*" tbb = "==2021.*" +[[package]] +name = "ml-dtypes" +version = "0.4.1" +description = "" +optional = false +python-versions = ">=3.9" +files = [ + {file = "ml_dtypes-0.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:1fe8b5b5e70cd67211db94b05cfd58dace592f24489b038dc6f9fe347d2e07d5"}, + {file = "ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c09a6d11d8475c2a9fd2bc0695628aec105f97cab3b3a3fb7c9660348ff7d24"}, + {file = "ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f5e8f75fa371020dd30f9196e7d73babae2abd51cf59bdd56cb4f8de7e13354"}, + {file = "ml_dtypes-0.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:15fdd922fea57e493844e5abb930b9c0bd0af217d9edd3724479fc3d7ce70e3f"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2d55b588116a7085d6e074cf0cdb1d6fa3875c059dddc4d2c94a4cc81c23e975"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e138a9b7a48079c900ea969341a5754019a1ad17ae27ee330f7ebf43f23877f9"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:74c6cfb5cf78535b103fde9ea3ded8e9f16f75bc07789054edc7776abfb3d752"}, + {file = "ml_dtypes-0.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:274cc7193dd73b35fb26bef6c5d40ae3eb258359ee71cd82f6e96a8c948bdaa6"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:827d3ca2097085cf0355f8fdf092b888890bb1b1455f52801a2d7756f056f54b"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:772426b08a6172a891274d581ce58ea2789cc8abc1c002a27223f314aaf894e7"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:126e7d679b8676d1a958f2651949fbfa182832c3cd08020d8facd94e4114f3e9"}, + {file = "ml_dtypes-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:df0fb650d5c582a9e72bb5bd96cfebb2cdb889d89daff621c8fbc60295eba66c"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e35e486e97aee577d0890bc3bd9e9f9eece50c08c163304008587ec8cfe7575b"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:560be16dc1e3bdf7c087eb727e2cf9c0e6a3d87e9f415079d2491cc419b3ebf5"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad0b757d445a20df39035c4cdeed457ec8b60d236020d2560dbc25887533cf50"}, + {file = "ml_dtypes-0.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:ef0d7e3fece227b49b544fa69e50e607ac20948f0043e9f76b44f35f229ea450"}, + {file = "ml_dtypes-0.4.1.tar.gz", hash = "sha256:fad5f2de464fd09127e49b7fd1252b9006fb43d2edc1ff112d390c324af5ca7a"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">1.20", markers = "python_version < \"3.10\""}, +] + +[package.extras] +dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] + [[package]] name = "mpmath" version = "1.3.0" @@ -3426,9 +3463,9 @@ files = [ astroid = ">=3.1.0,<=3.2.0-dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ - {version = ">=0.2", markers = "python_version < \"3.11\""}, {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=0.2", markers = "python_version < \"3.11\""}, ] isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" mccabe = ">=0.6,<0.8" @@ -5523,4 +5560,4 @@ sentence-transformers = ["sentence-transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "be9b5df2ff3600823749e4d0bfffe148c6bb04f88fa287a3dfae712ade9fd06e" +content-hash = "4dbfe0e66ba3b90c5cb8746034ec5e870e07eb8207c5ba95ac700939c91ac89d" diff --git a/pyproject.toml b/pyproject.toml index 6ecad04e..027652d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ redis = ">=5.0.0" pydantic = { version = ">=2,<3" } tenacity = ">=8.2.2" tabulate = { version = ">=0.9.0,<1" } +ml-dtypes = "^0.4.0" openai = { version = ">=1.13.0", optional = true } sentence-transformers = { version = ">=2.2.2", optional = true } google-cloud-aiplatform = { version = ">=1.26", optional = true } diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index fe970252..ad72d749 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -48,9 +48,9 @@ def non_empty_metadata(cls, v): raise TypeError("Metadata must be a dictionary.") return v - def to_dict(self) -> Dict: + def to_dict(self, dtype: str) -> Dict: data = self.dict(exclude_none=True) - data["prompt_vector"] = array_to_buffer(self.prompt_vector) + data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype) if self.metadata is not None: data["metadata"] = serialize(self.metadata) if self.filters is not None: @@ -112,7 +112,7 @@ def to_dict(self) -> Dict: class SemanticCacheIndexSchema(IndexSchema): @classmethod - def from_params(cls, name: str, prefix: str, vector_dims: int): + def from_params(cls, name: str, prefix: str, vector_dims: int, dtype: str): return cls( index={"name": name, "prefix": prefix}, # type: ignore @@ -126,7 +126,7 @@ def from_params(cls, name: str, prefix: str, vector_dims: int): "type": "vector", "attrs": { "dims": vector_dims, - "datatype": "float32", + "datatype": dtype, "distance_metric": "cosine", "algorithm": "flat", }, diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index d284b602..4fb01cff 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -104,7 +104,10 @@ def __init__( ] # Create semantic cache schema and index - schema = SemanticCacheIndexSchema.from_params(name, prefix, vectorizer.dims) + dtype = kwargs.get("dtype", "float32") + schema = SemanticCacheIndexSchema.from_params( + name, prefix, vectorizer.dims, dtype + ) schema = self._modify_schema(schema, filterable_fields) self._index = SearchIndex(schema=schema) @@ -137,6 +140,7 @@ def __init__( self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims, # type: ignore ) self._vectorizer = vectorizer + self._dtype = self.index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.datatype # type: ignore[union-attr] def _modify_schema( self, @@ -286,7 +290,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - return self._vectorizer.embed(prompt) + return self._vectorizer.embed(prompt, dtype=self._dtype) async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: """Converts a text prompt to its vector representation using the @@ -368,6 +372,7 @@ def check( num_results=num_results, return_score=True, filter_expression=filter_expression, + dtype=self._dtype, ) # Search the cache! @@ -538,7 +543,7 @@ def store( # Load cache entry with TTL ttl = ttl or self._ttl keys = self._index.load( - data=[cache_entry.to_dict()], + data=[cache_entry.to_dict(self._dtype)], ttl=ttl, id_field=ENTRY_ID_FIELD_NAME, ) @@ -602,7 +607,7 @@ async def astore( # Load cache entry with TTL ttl = ttl or self._ttl keys = await aindex.load( - data=[cache_entry.to_dict()], + data=[cache_entry.to_dict(self._dtype)], ttl=ttl, id_field=ENTRY_ID_FIELD_NAME, ) diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index 04272e2a..6de61fb0 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -4,7 +4,7 @@ from pydantic.v1 import BaseModel, Field, validator from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME -from redisvl.schema import IndexInfo, IndexSchema +from redisvl.schema import IndexSchema class Route(BaseModel): @@ -89,7 +89,7 @@ class SemanticRouterIndexSchema(IndexSchema): """Customized index schema for SemanticRouter.""" @classmethod - def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema": + def from_params(cls, name: str, vector_dims: int, dtype: str): """Create an index schema based on router name and vector dimensions. Args: @@ -100,7 +100,7 @@ def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema" SemanticRouterIndexSchema: The constructed index schema. """ return cls( - index=IndexInfo(name=name, prefix=name), + index={"name": name, "prefix": name}, # type: ignore fields=[ # type: ignore {"name": "route_name", "type": "tag"}, {"name": "reference", "type": "text"}, @@ -111,7 +111,7 @@ def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema" "algorithm": "flat", "dims": vector_dims, "distance_metric": "cosine", - "datatype": "float32", + "datatype": dtype, }, }, ], diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index f1e235aa..52ea321f 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -85,17 +85,23 @@ def __init__( vectorizer=vectorizer, routing_config=routing_config, ) - self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs) + dtype = kwargs.get("dtype", "float32") + self._initialize_index( + redis_client, redis_url, overwrite, dtype, **connection_kwargs + ) def _initialize_index( self, redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", overwrite: bool = False, + dtype: str = "float32", **connection_kwargs, ): """Initialize the search index and handle Redis connection.""" - schema = SemanticRouterIndexSchema.from_params(self.name, self.vectorizer.dims) + schema = SemanticRouterIndexSchema.from_params( + self.name, self.vectorizer.dims, dtype + ) self._index = SearchIndex(schema=schema) if redis_client: @@ -103,8 +109,18 @@ def _initialize_index( elif redis_url: self._index.connect(redis_url=redis_url, **connection_kwargs) + # Check for existing router index existed = self._index.exists() - self._index.create(overwrite=overwrite) + if not overwrite and existed: + existing_index = SearchIndex.from_existing( + self.name, redis_client=self._index.client + ) + if existing_index.schema != self._index.schema: + raise ValueError( + f"Existing index {self.name} schema does not match the user provided schema for the semantic router. " + "If you wish to overwrite the index schema, set overwrite=True during initialization." + ) + self._index.create(overwrite=overwrite, drop=False) if not existed or overwrite: # write the routes to Redis @@ -153,7 +169,9 @@ def _add_routes(self, routes: List[Route]): for route in routes: # embed route references as a single batch reference_vectors = self.vectorizer.embed_many( - [reference for reference in route.references], as_buffer=True + [reference for reference in route.references], + as_buffer=True, + dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) # set route references for i, reference in enumerate(route.references): @@ -230,6 +248,7 @@ def _classify_route( vector_field_name=ROUTE_VECTOR_FIELD_NAME, distance_threshold=distance_threshold, return_fields=["route_name"], + dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) aggregate_request = self._build_aggregate_request( @@ -282,6 +301,7 @@ def _classify_multi_route( vector_field_name=ROUTE_VECTOR_FIELD_NAME, distance_threshold=distance_threshold, return_fields=["route_name"], + dtype=self._index.schema.fields[ROUTE_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) aggregate_request = self._build_aggregate_request( vector_range_query, aggregation_method, max_k diff --git a/redisvl/extensions/session_manager/schema.py b/redisvl/extensions/session_manager/schema.py index aabc67b4..f943889d 100644 --- a/redisvl/extensions/session_manager/schema.py +++ b/redisvl/extensions/session_manager/schema.py @@ -48,15 +48,14 @@ def generate_id(cls, values): ) return values - def to_dict(self) -> Dict: + def to_dict(self, dtype: Optional[str] = None) -> Dict: data = self.dict(exclude_none=True) # handle optional fields if SESSION_VECTOR_FIELD_NAME in data: data[SESSION_VECTOR_FIELD_NAME] = array_to_buffer( - data[SESSION_VECTOR_FIELD_NAME] + data[SESSION_VECTOR_FIELD_NAME], dtype # type: ignore[arg-type] ) - return data @@ -80,7 +79,7 @@ def from_params(cls, name: str, prefix: str): class SemanticSessionIndexSchema(IndexSchema): @classmethod - def from_params(cls, name: str, prefix: str, vectorizer_dims: int): + def from_params(cls, name: str, prefix: str, vectorizer_dims: int, dtype: str): return cls( index={"name": name, "prefix": prefix}, # type: ignore @@ -95,7 +94,7 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int): "type": "vector", "attrs": { "dims": vectorizer_dims, - "datatype": "float32", + "datatype": dtype, "distance_metric": "cosine", "algorithm": "flat", }, diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 29474904..ce6c8f0d 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -24,7 +24,6 @@ class SemanticSessionManager(BaseSessionManager): - vector_field_name: str = "vector_field" def __init__( self, @@ -36,6 +35,7 @@ def __init__( redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", connection_kwargs: Dict[str, Any] = {}, + overwrite: bool = False, **kwargs, ): """Initialize session memory with index @@ -60,6 +60,8 @@ def __init__( redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. connection_kwargs (Dict[str, Any]): The connection arguments for the redis client. Defaults to empty {}. + overwrite (bool): Whether or not to force overwrite the schema for + the semantic session index. Defaults to false. The proposed schema will support a single vector embedding constructed from either the prompt or response in a single string. @@ -75,8 +77,9 @@ def __init__( self.set_distance_threshold(distance_threshold) + dtype = kwargs.get("dtype", "float32") schema = SemanticSessionIndexSchema.from_params( - name, prefix, self._vectorizer.dims + name, prefix, self._vectorizer.dims, dtype ) self._index = SearchIndex(schema=schema) @@ -87,7 +90,17 @@ def __init__( elif redis_url: self._index.connect(redis_url=redis_url, **connection_kwargs) - self._index.create(overwrite=False) + # Check for existing session index + if not overwrite and self._index.exists(): + existing_index = SearchIndex.from_existing( + name, redis_client=self._index.client + ) + if existing_index.schema != self._index.schema: + raise ValueError( + f"Existing index {name} schema does not match the user provided schema for the semantic session. " + "If you wish to overwrite the index schema, set overwrite=True during initialization." + ) + self._index.create(overwrite=overwrite, drop=False) self._default_session_filter = Tag(SESSION_FIELD_NAME) == self._session_tag @@ -202,6 +215,7 @@ def get_relevant( num_results=top_k, return_score=True, filter_expression=session_filter, + dtype=self._index.schema.fields[SESSION_VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) messages = self._index.query(query) @@ -327,7 +341,7 @@ def add_messages( if TOOL_FIELD_NAME in message: chat_message.tool_call_id = message[TOOL_FIELD_NAME] - chat_messages.append(chat_message.to_dict()) + chat_messages.append(chat_message.to_dict(dtype=self._index.schema.fields[SESSION_VECTOR_FIELD_NAME].attrs.datatype)) # type: ignore[union-attr] self._index.load(data=chat_messages, id_field=ID_FIELD_NAME) diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index 9680133f..5a5db108 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -122,7 +122,7 @@ def get_recent( raw: bool = False, session_tag: Optional[str] = None, ) -> Union[List[str], List[Dict[str, str]]]: - """Retreive the recent conversation history in sequential order. + """Retrieve the recent conversation history in sequential order. Args: top_k (int): The number of previous messages to return. Default is 5. diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 9ba05481..856a2572 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List, Optional, Union -import numpy as np from redis.commands.search.query import Query as RedisQuery from redisvl.query.filter import FilterExpression @@ -169,10 +168,6 @@ def _build_query_string(self) -> str: class BaseVectorQuery: - DTYPES: Dict[str, Any] = { - "float32": np.float32, - "float64": np.float64, - } DISTANCE_ID: str = "vector_distance" VECTOR_PARAM: str = "vector" @@ -264,7 +259,7 @@ def params(self) -> Dict[str, Any]: if isinstance(self._vector, bytes): vector = self._vector else: - vector = array_to_buffer(self._vector, dtype=self.DTYPES[self._dtype]) + vector = array_to_buffer(self._vector, dtype=self._dtype) return {self.VECTOR_PARAM: vector} @@ -390,7 +385,7 @@ def params(self) -> Dict[str, Any]: if isinstance(self._vector, bytes): vector_param = self._vector else: - vector_param = array_to_buffer(self._vector, dtype=self.DTYPES[self._dtype]) + vector_param = array_to_buffer(self._vector, dtype=self._dtype) return { self.VECTOR_PARAM: vector_param, diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py index 28d15509..c619a11d 100644 --- a/redisvl/redis/utils.py +++ b/redisvl/redis/utils.py @@ -2,6 +2,9 @@ from typing import Any, Dict, List, Optional import numpy as np +from ml_dtypes import bfloat16 + +from redisvl.schema.fields import VectorDataType def make_dict(values: List[Any]) -> Dict[Any, Any]: @@ -30,14 +33,26 @@ def convert_bytes(data: Any) -> Any: return data -def array_to_buffer(array: List[float], dtype: Any = np.float32) -> bytes: +def array_to_buffer(array: List[float], dtype: str) -> bytes: """Convert a list of floats into a numpy byte string.""" - return np.array(array).astype(dtype).tobytes() + try: + VectorDataType(dtype.upper()) + except ValueError: + raise ValueError( + f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}" + ) + return np.array(array).astype(dtype.lower()).tobytes() -def buffer_to_array(buffer: bytes, dtype: Any = np.float32) -> List[float]: +def buffer_to_array(buffer: bytes, dtype: str) -> List[float]: """Convert bytes into into a list of floats.""" - return np.frombuffer(buffer, dtype=dtype).tolist() + try: + VectorDataType(dtype.upper()) + except ValueError: + raise ValueError( + f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}" + ) + return np.frombuffer(buffer, dtype=dtype.lower()).tolist() def hashify(content: str, extras: Optional[Dict[str, Any]] = None) -> str: diff --git a/redisvl/schema/fields.py b/redisvl/schema/fields.py index 7dd85bea..132e785f 100644 --- a/redisvl/schema/fields.py +++ b/redisvl/schema/fields.py @@ -26,6 +26,8 @@ class VectorDistanceMetric(str, Enum): class VectorDataType(str, Enum): + BFLOAT16 = "BFLOAT16" + FLOAT16 = "FLOAT16" FLOAT32 = "FLOAT32" FLOAT64 = "FLOAT64" diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index 858e94b1..5fcd1b4a 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -81,7 +81,11 @@ def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None): else: yield seq[pos : pos + size] - def _process_embedding(self, embedding: List[float], as_buffer: bool): + def _process_embedding(self, embedding: List[float], as_buffer: bool, **kwargs): if as_buffer: - return array_to_buffer(embedding) + if "dtype" not in kwargs: + raise RuntimeError( + "dtype is required if converting from float to byte string." + ) + return array_to_buffer(embedding, kwargs["dtype"]) return embedding diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index 734fef5b..3129c0b0 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -194,7 +194,8 @@ def embed_many( for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embeddings.create(input=batch, model=self.model) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -231,7 +232,7 @@ def embed( if preprocess: text = preprocess(text) result = self._client.embeddings.create(input=[text], model=self.model) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @retry( wait=wait_random_exponential(min=1, max=60), @@ -274,7 +275,8 @@ async def aembed_many( input=batch, model=self.model ) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -311,7 +313,7 @@ async def aembed( if preprocess: text = preprocess(text) result = await self._aclient.embeddings.create(input=[text], model=self.model) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index 462783a1..94584b91 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -160,7 +160,7 @@ def embed( embedding = self._client.embed( texts=[text], model=self.model, input_type=input_type ).embeddings[0] - return self._process_embedding(embedding, as_buffer) + return self._process_embedding(embedding, as_buffer, **kwargs) @retry( wait=wait_random_exponential(min=1, max=60), @@ -230,7 +230,7 @@ def embed_many( texts=batch, model=self.model, input_type=input_type ) embeddings += [ - self._process_embedding(embedding, as_buffer) + self._process_embedding(embedding, as_buffer, **kwargs) for embedding in response.embeddings ] return embeddings diff --git a/redisvl/utils/vectorize/text/custom.py b/redisvl/utils/vectorize/text/custom.py index bf0eec4a..8dc42c12 100644 --- a/redisvl/utils/vectorize/text/custom.py +++ b/redisvl/utils/vectorize/text/custom.py @@ -174,7 +174,7 @@ def embed( text = preprocess(text) else: result = self._embed_func(text, **kwargs) - return self._process_embedding(result, as_buffer) + return self._process_embedding(result, as_buffer, **kwargs) def embed_many( self, @@ -213,7 +213,9 @@ def embed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): results = self._embed_many_func(batch, **kwargs) - embeddings += [self._process_embedding(r, as_buffer) for r in results] + embeddings += [ + self._process_embedding(r, as_buffer, **kwargs) for r in results + ] return embeddings async def aembed( @@ -249,7 +251,7 @@ async def aembed( text = preprocess(text) else: result = await self._aembed_func(text, **kwargs) - return self._process_embedding(result, as_buffer) + return self._process_embedding(result, as_buffer, **kwargs) async def aembed_many( self, @@ -288,7 +290,9 @@ async def aembed_many( embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): results = await self._aembed_many_func(batch, **kwargs) - embeddings += [self._process_embedding(r, as_buffer) for r in results] + embeddings += [ + self._process_embedding(r, as_buffer, **kwargs) for r in results + ] return embeddings @property diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index ab983ffe..d0a6243d 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -100,7 +100,7 @@ def embed( if preprocess: text = preprocess(text) embedding = self._client.encode([text])[0] - return self._process_embedding(embedding.tolist(), as_buffer) + return self._process_embedding(embedding.tolist(), as_buffer, **kwargs) def embed_many( self, @@ -138,7 +138,7 @@ def embed_many( batch_embeddings = self._client.encode(batch) embeddings.extend( [ - self._process_embedding(embedding.tolist(), as_buffer) + self._process_embedding(embedding.tolist(), as_buffer, **kwargs) for embedding in batch_embeddings ] ) diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py index 8776ef3d..7d4f00f5 100644 --- a/redisvl/utils/vectorize/text/mistral.py +++ b/redisvl/utils/vectorize/text/mistral.py @@ -144,7 +144,8 @@ def embed_many( for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embeddings(model=self.model, input=batch) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -181,7 +182,7 @@ def embed( if preprocess: text = preprocess(text) result = self._client.embeddings(model=self.model, input=[text]) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @retry( wait=wait_random_exponential(min=1, max=60), @@ -222,7 +223,8 @@ async def aembed_many( for batch in self.batchify(texts, batch_size, preprocess): response = await self._aclient.embeddings(model=self.model, input=batch) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -259,7 +261,7 @@ async def aembed( if preprocess: text = preprocess(text) result = await self._aclient.embeddings(model=self.model, input=[text]) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index 5921bda8..ae5d19dc 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -148,7 +148,8 @@ def embed_many( for batch in self.batchify(texts, batch_size, preprocess): response = self._client.embeddings.create(input=batch, model=self.model) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -185,7 +186,7 @@ def embed( if preprocess: text = preprocess(text) result = self._client.embeddings.create(input=[text], model=self.model) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @retry( wait=wait_random_exponential(min=1, max=60), @@ -228,7 +229,8 @@ async def aembed_many( input=batch, model=self.model ) embeddings += [ - self._process_embedding(r.embedding, as_buffer) for r in response.data + self._process_embedding(r.embedding, as_buffer, **kwargs) + for r in response.data ] return embeddings @@ -265,7 +267,7 @@ async def aembed( if preprocess: text = preprocess(text) result = await self._aclient.embeddings.create(input=[text], model=self.model) - return self._process_embedding(result.data[0].embedding, as_buffer) + return self._process_embedding(result.data[0].embedding, as_buffer, **kwargs) @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index 2ab9b83b..71e2e433 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -155,7 +155,7 @@ def embed_many( for batch in self.batchify(texts, batch_size, preprocess): response = self._client.get_embeddings(batch) embeddings += [ - self._process_embedding(r.values, as_buffer) for r in response + self._process_embedding(r.values, as_buffer, **kwargs) for r in response ] return embeddings @@ -192,7 +192,7 @@ def embed( if preprocess: text = preprocess(text) result = self._client.get_embeddings([text]) - return self._process_embedding(result[0].values, as_buffer) + return self._process_embedding(result[0].values, as_buffer, **kwargs) @property def type(self) -> str: diff --git a/tests/integration/test_flow.py b/tests/integration/test_flow.py index 538b02d3..b448a636 100644 --- a/tests/integration/test_flow.py +++ b/tests/integration/test_flow.py @@ -51,7 +51,10 @@ def test_simple(client, schema, sample_data): # Prepare and load the data based on storage type def hash_preprocess(item: dict) -> dict: - return {**item, "user_embedding": array_to_buffer(item["user_embedding"])} + return { + **item, + "user_embedding": array_to_buffer(item["user_embedding"], "float32"), + } if index.storage_type == StorageType.HASH: index.load(sample_data, preprocess=hash_preprocess, id_field="user") diff --git a/tests/integration/test_flow_async.py b/tests/integration/test_flow_async.py index 3557ded8..fbfa7d22 100644 --- a/tests/integration/test_flow_async.py +++ b/tests/integration/test_flow_async.py @@ -55,7 +55,10 @@ async def test_simple(async_client, schema, sample_data): # Prepare and load the data based on storage type async def hash_preprocess(item: dict) -> dict: - return {**item, "user_embedding": array_to_buffer(item["user_embedding"])} + return { + **item, + "user_embedding": array_to_buffer(item["user_embedding"], "float32"), + } if index.storage_type == StorageType.HASH: await index.load(sample_data, preprocess=hash_preprocess, id_field="user") diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 6c106e87..4eb4d6f7 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -1,4 +1,5 @@ import asyncio +import os from collections import namedtuple from time import sleep, time @@ -821,7 +822,6 @@ def test_no_key_collision_on_identical_prompts(redis_url): private_cache.store( prompt="What's the phone number linked in my account?", response="The number on file is 123-555-9999", - ###filters={"user_id": "cerioni"}, filters={"user_id": "cerioni", "zip_code": 90210}, ) @@ -843,3 +843,35 @@ def test_no_key_collision_on_identical_prompts(redis_url): filter_expression=zip_code_filter, ) assert len(filtered_results) == 2 + + +def test_create_cache_with_different_vector_types(): + try: + bfloat_cache = SemanticCache(name="bfloat_cache", dtype="bfloat16") + bfloat_cache.store("bfloat16 prompt", "bfloat16 response") + + float16_cache = SemanticCache(name="float16_cache", dtype="float16") + float16_cache.store("float16 prompt", "float16 response") + + float32_cache = SemanticCache(name="float32_cache", dtype="float32") + float32_cache.store("float32 prompt", "float32 response") + + float64_cache = SemanticCache(name="float64_cache", dtype="float64") + float64_cache.store("float64 prompt", "float64 response") + + for cache in [bfloat_cache, float16_cache, float32_cache, float64_cache]: + cache.set_threshold(0.6) + assert len(cache.check("float prompt", num_results=5)) == 1 + except: + pytest.skip("Not using a late enough version of Redis") + + +def test_bad_dtype_connecting_to_existing_cache(): + try: + cache = SemanticCache(name="float64_cache", dtype="float64") + same_type = SemanticCache(name="float64_cache", dtype="float64") + except ValueError: + pytest.skip("Not using a late enough version of Redis") + + with pytest.raises(ValueError): + bad_type = SemanticCache(name="float64_cache", dtype="float16") diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index df348f83..752d5fe7 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -103,7 +103,10 @@ def index(sample_data, redis_url): # Prepare and load the data def hash_preprocess(item: dict) -> dict: - return {**item, "user_embedding": array_to_buffer(item["user_embedding"])} + return { + **item, + "user_embedding": array_to_buffer(item["user_embedding"], "float32"), + } index.load(sample_data, preprocess=hash_preprocess) diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index b2a7c716..194a6f98 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -1,3 +1,4 @@ +import os import pathlib import pytest @@ -175,7 +176,7 @@ def test_to_dict(semantic_router): def test_from_dict(semantic_router): router_dict = semantic_router.to_dict() new_router = SemanticRouter.from_dict( - router_dict, redis_client=semantic_router._index.client + router_dict, redis_client=semantic_router._index.client, overwrite=True ) assert new_router == semantic_router @@ -223,7 +224,7 @@ def test_yaml_invalid_file_path(): def test_idempotent_to_dict(semantic_router): router_dict = semantic_router.to_dict() new_router = SemanticRouter.from_dict( - router_dict, redis_client=semantic_router._index.client + router_dict, redis_client=semantic_router._index.client, overwrite=True ) assert new_router.to_dict() == router_dict @@ -237,3 +238,59 @@ def test_bad_connection_info(routes): redis_url="redis://localhost:6389", # bad connection url overwrite=False, ) + + +def test_different_vector_dtypes(routes): + try: + bfloat_router = SemanticRouter( + name="bfloat_router", + routes=routes, + dtype="bfloat16", + ) + + float16_router = SemanticRouter( + name="float16_router", + routes=routes, + dtype="float16", + ) + + float32_router = SemanticRouter( + name="float32_router", + routes=routes, + dtype="float32", + ) + + float64_router = SemanticRouter( + name="float64_router", + routes=routes, + dtype="float64", + ) + + for router in [bfloat_router, float16_router, float32_router, float64_router]: + assert len(router.route_many("hello", max_k=5)) == 1 + except: + pytest.skip("Not using a late enough version of Redis") + + +def test_bad_dtype_connecting_to_exiting_router(routes): + try: + router = SemanticRouter( + name="float64 router", + routes=routes, + dtype="float64", + ) + + same_type = SemanticRouter( + name="float64 router", + routes=routes, + dtype="float64", + ) + except ValueError: + pytest.skip("Not using a late enough version of Redis") + + with pytest.raises(ValueError): + bad_type = SemanticRouter( + name="float64 router", + routes=routes, + dtype="float16", + ) diff --git a/tests/integration/test_session_manager.py b/tests/integration/test_session_manager.py index 3f458a50..898c3ab5 100644 --- a/tests/integration/test_session_manager.py +++ b/tests/integration/test_session_manager.py @@ -1,3 +1,5 @@ +import os + import pytest from redis.exceptions import ConnectionError @@ -17,7 +19,7 @@ def standard_session(app_name, client): @pytest.fixture def semantic_session(app_name, client): - session = SemanticSessionManager(app_name, redis_client=client) + session = SemanticSessionManager(app_name, redis_client=client, overwrite=True) yield session session.clear() session.delete() @@ -285,7 +287,7 @@ def test_standard_clear(standard_session): # test semantic session manager def test_semantic_specify_client(client): session = SemanticSessionManager( - name="test_app", session_tag="abc", redis_client=client + name="test_app", session_tag="abc", redis_client=client, overwrite=True ) assert isinstance(session._index.client, type(client)) @@ -537,3 +539,35 @@ def test_semantic_drop(semantic_session): {"role": "llm", "content": "third response"}, {"role": "user", "content": "fourth prompt"}, ] + + +def test_different_vector_dtypes(): + try: + bfloat_sess = SemanticSessionManager(name="bfloat_session", dtype="bfloat16") + bfloat_sess.add_message({"role": "user", "content": "bfloat message"}) + + float16_sess = SemanticSessionManager(name="float16_session", dtype="float16") + float16_sess.add_message({"role": "user", "content": "float16 message"}) + + float32_sess = SemanticSessionManager(name="float32_session", dtype="float32") + float32_sess.add_message({"role": "user", "content": "float32 message"}) + + float64_sess = SemanticSessionManager(name="float64_session", dtype="float64") + float64_sess.add_message({"role": "user", "content": "float64 message"}) + + for sess in [bfloat_sess, float16_sess, float32_sess, float64_sess]: + sess.set_distance_threshold(0.7) + assert len(sess.get_relevant("float message")) == 1 + except: + pytest.skip("Not using a late enough version of Redis") + + +def test_bad_dtype_connecting_to_exiting_session(): + try: + session = SemanticSessionManager(name="float64 session", dtype="float64") + same_type = SemanticSessionManager(name="float64 session", dtype="float64") + except ValueError: + pytest.skip("Not using a late enough version of Redis") + + with pytest.raises(ValueError): + bad_type = SemanticSessionManager(name="float64 session", dtype="float16") diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py index aa3a3add..a686ebe0 100644 --- a/tests/unit/test_llmcache_schema.py +++ b/tests/unit/test_llmcache_schema.py @@ -47,10 +47,10 @@ def test_cache_entry_to_dict(): metadata={"author": "John"}, filters={"category": "technology"}, ) - result = entry.to_dict() + result = entry.to_dict(dtype="float32") assert result["entry_id"] == hashify("What is AI?", {"category": "technology"}) assert result["metadata"] == json.dumps({"author": "John"}) - assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3]) + assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3], "float32") assert result["category"] == "technology" assert "filters" not in result @@ -109,7 +109,7 @@ def test_cache_entry_with_empty_optional_fields(): response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], ) - result = entry.to_dict() + result = entry.to_dict(dtype="float32") assert "metadata" not in result assert "filters" not in result diff --git a/tests/unit/test_session_schema.py b/tests/unit/test_session_schema.py index b25f6564..3beb098c 100644 --- a/tests/unit/test_session_schema.py +++ b/tests/unit/test_session_schema.py @@ -96,14 +96,14 @@ def test_chat_message_to_dict(): vector_field=vector_field, ) - data = chat_message.to_dict() + data = chat_message.to_dict(dtype="float32") assert data["entry_id"] == f"{session_tag}:{timestamp}" assert data["role"] == "user" assert data["content"] == content assert data["session_tag"] == session_tag assert data["timestamp"] == timestamp - assert data["vector_field"] == array_to_buffer(vector_field) + assert data["vector_field"] == array_to_buffer(vector_field, "float32") def test_chat_message_missing_fields(): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index ca535c5a..d4ffaaaf 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from ml_dtypes import bfloat16 from redisvl.redis.utils import ( array_to_buffer, @@ -50,27 +51,22 @@ def test_simple_byte_buffer_to_floats(): """Test conversion of a simple byte buffer into floats""" buffer = np.array([1.0, 2.0, 3.0], dtype=np.float32).tobytes() expected = [1.0, 2.0, 3.0] - assert buffer_to_array(buffer, dtype=np.float32) == expected + assert buffer_to_array(buffer, dtype="float32") == expected -def test_different_data_types(): +def test_converting_different_data_types(): """Test conversion with different data types""" - # Integer test - buffer = np.array([1, 2, 3], dtype=np.int32).tobytes() - expected = [1, 2, 3] - assert buffer_to_array(buffer, dtype=np.int32) == expected - # Float64 test buffer = np.array([1.0, 2.0, 3.0], dtype=np.float64).tobytes() expected = [1.0, 2.0, 3.0] - assert buffer_to_array(buffer, dtype=np.float64) == expected + assert buffer_to_array(buffer, dtype="float64") == expected def test_empty_byte_buffer(): """Test conversion of an empty byte buffer""" buffer = b"" expected = [] - assert buffer_to_array(buffer, dtype=np.float32) == expected + assert buffer_to_array(buffer, dtype="float32") == expected def test_plain_bytes_to_string(): @@ -119,7 +115,7 @@ def test_simple_list_to_bytes_default_dtype(): """Test conversion of a simple list of floats to bytes using the default dtype""" array = [1.0, 2.0, 3.0] expected = np.array(array, dtype=np.float32).tobytes() - assert array_to_buffer(array) == expected + assert array_to_buffer(array, "float32") == expected def test_list_to_bytes_non_default_dtype(): @@ -127,17 +123,17 @@ def test_list_to_bytes_non_default_dtype(): array = [1.0, 2.0, 3.0] dtype = np.float64 expected = np.array(array, dtype=dtype).tobytes() - assert array_to_buffer(array, dtype=dtype) == expected + assert array_to_buffer(array, dtype="float64") == expected def test_empty_list_to_bytes(): """Test conversion of an empty list""" array = [] expected = np.array(array, dtype=np.float32).tobytes() - assert array_to_buffer(array) == expected + assert array_to_buffer(array, dtype="float32") == expected -@pytest.mark.parametrize("dtype", [np.int32, np.float64]) +@pytest.mark.parametrize("dtype", ["float64", "float32", "float16", "bfloat16"]) def test_conversion_with_various_dtypes(dtype): """Test conversion of a list of floats to bytes with various dtypes""" array = [1.0, -2.0, 3.5] @@ -148,5 +144,5 @@ def test_conversion_with_various_dtypes(dtype): def test_conversion_with_invalid_floats(): """Test conversion with invalid float values (numpy should handle them)""" array = [float("inf"), float("-inf"), float("nan")] - result = array_to_buffer(array) + result = array_to_buffer(array, "float16") assert len(result) > 0 # Simple check to ensure it returns anything