Skip to content

Commit ed557f6

Browse files
puts dtype test in loop
1 parent 2281d35 commit ed557f6

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

tests/integration/test_vectorizers.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
2-
import numpy as np
32

3+
import numpy as np
44
import pytest
55

66
from redisvl.redis.utils import buffer_to_array
@@ -239,28 +239,20 @@ def bad_return_type(text: str) -> str:
239239
custom_embed_func, embed_many=bad_return_type
240240
)
241241

242+
242243
def test_dtypes(vectorizer):
243244
if isinstance(vectorizer, CustomTextVectorizer):
244245
pytest.skip("skipping custom text vectorizer")
245246
words = "hello"
246-
raw = vectorizer.embed(words, as_buffer=False)
247-
248-
default = vectorizer.embed(words, as_buffer=True)
249-
assert buffer_to_array(default, dtype="float32") == raw
247+
raw = vectorizer.embed(words, as_buffer=False, input_type="search_query")
250248

251-
float16 = vectorizer.embed(words, as_buffer=True, dtype="float16")
252-
# assert buffer_to_array(float16, dtype="float16") == raw # fails
253-
assert np.allclose(buffer_to_array(float16, dtype="float16"), raw, atol=1e-03)
254-
255-
float32 = vectorizer.embed(words, as_buffer=True, dtype="float32")
256-
assert buffer_to_array(float32, dtype="float32") == raw
257-
258-
float64 = vectorizer.embed(words, as_buffer=True, dtype="float64")
259-
assert buffer_to_array(float64, dtype="float64") == raw
260-
261-
bfloat16 = vectorizer.embed(words, as_buffer=True, dtype="bfloat16")
262-
# assert buffer_to_array(bfloat16, dtype="bfloat16") == raw # fails
263-
assert np.allclose(buffer_to_array(bfloat16, dtype="bfloat16"), raw, atol=1e-03)
249+
default = vectorizer.embed(words, as_buffer=True, input_type="search_query")
250+
assert np.allclose(buffer_to_array(default, dtype="float32"), raw, atol=1e-03)
251+
for dtype in ["float16", "float32", "float64", "bfloat16"]:
252+
embedding = vectorizer.embed(
253+
words, as_buffer=True, dtype=dtype, input_type="search_query"
254+
)
255+
assert np.allclose(buffer_to_array(embedding, dtype=dtype), raw, atol=1e-03)
264256

265257

266258
@pytest.fixture(

0 commit comments

Comments
 (0)