Skip to content

Commit c105447

Browse files
adds tests for setting dtype in vectorizer
1 parent da4b129 commit c105447

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/integration/test_vectorizers.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
22

3+
import numpy as np
34
import pytest
45

6+
from redisvl.redis.utils import buffer_to_array
57
from redisvl.utils.vectorize import (
68
AzureOpenAITextVectorizer,
79
BedrockTextVectorizer,
@@ -238,6 +240,29 @@ def bad_return_type(text: str) -> str:
238240
)
239241

240242

243+
def test_dtypes(vectorizer):
244+
words = "hello"
245+
246+
raw = vectorizer.embed(words, as_buffer=False)
247+
248+
default = vectorizer.embed(words, as_buffer=True)
249+
assert buffer_to_array(default, dtype="float32") == raw
250+
251+
float16 = vectorizer.embed(words, as_buffer=True, dtype="float16")
252+
# assert buffer_to_array(float16, dtype="float16") == raw # fails
253+
assert np.allclose(buffer_to_array(float16, dtype="float16"), raw, atol=1e-04)
254+
255+
float32 = vectorizer.embed(words, as_buffer=True, dtype="float32")
256+
assert buffer_to_array(float32, dtype="float32") == raw
257+
258+
float64 = vectorizer.embed(words, as_buffer=True, dtype="float64")
259+
assert buffer_to_array(float64, dtype="float64") == raw
260+
261+
bfloat16 = vectorizer.embed(words, as_buffer=True, dtype="bfloat16")
262+
# assert buffer_to_array(bfloat16, dtype="bfloat16") == raw # fails
263+
assert np.allclose(buffer_to_array(bfloat16, dtype="bfloat16"), raw, atol=1e-03)
264+
265+
241266
@pytest.fixture(
242267
params=[
243268
OpenAITextVectorizer,

0 commit comments

Comments
 (0)