|
1 | 1 | import os |
2 | | -import numpy as np |
3 | 2 |
|
| 3 | +import numpy as np |
4 | 4 | import pytest |
5 | 5 |
|
6 | 6 | from redisvl.redis.utils import buffer_to_array |
@@ -239,28 +239,20 @@ def bad_return_type(text: str) -> str: |
239 | 239 | custom_embed_func, embed_many=bad_return_type |
240 | 240 | ) |
241 | 241 |
|
| 242 | + |
242 | 243 | def test_dtypes(vectorizer): |
243 | 244 | if isinstance(vectorizer, CustomTextVectorizer): |
244 | 245 | pytest.skip("skipping custom text vectorizer") |
245 | 246 | words = "hello" |
246 | | - raw = vectorizer.embed(words, as_buffer=False) |
247 | | - |
248 | | - default = vectorizer.embed(words, as_buffer=True) |
249 | | - assert buffer_to_array(default, dtype="float32") == raw |
| 247 | + raw = vectorizer.embed(words, as_buffer=False, input_type="search_query") |
250 | 248 |
|
251 | | - float16 = vectorizer.embed(words, as_buffer=True, dtype="float16") |
252 | | - # assert buffer_to_array(float16, dtype="float16") == raw # fails |
253 | | - assert np.allclose(buffer_to_array(float16, dtype="float16"), raw, atol=1e-03) |
254 | | - |
255 | | - float32 = vectorizer.embed(words, as_buffer=True, dtype="float32") |
256 | | - assert buffer_to_array(float32, dtype="float32") == raw |
257 | | - |
258 | | - float64 = vectorizer.embed(words, as_buffer=True, dtype="float64") |
259 | | - assert buffer_to_array(float64, dtype="float64") == raw |
260 | | - |
261 | | - bfloat16 = vectorizer.embed(words, as_buffer=True, dtype="bfloat16") |
262 | | - # assert buffer_to_array(bfloat16, dtype="bfloat16") == raw # fails |
263 | | - assert np.allclose(buffer_to_array(bfloat16, dtype="bfloat16"), raw, atol=1e-03) |
| 249 | + default = vectorizer.embed(words, as_buffer=True, input_type="search_query") |
| 250 | + assert np.allclose(buffer_to_array(default, dtype="float32"), raw, atol=1e-03) |
| 251 | + for dtype in ["float16", "float32", "float64", "bfloat16"]: |
| 252 | + embedding = vectorizer.embed( |
| 253 | + words, as_buffer=True, dtype=dtype, input_type="search_query" |
| 254 | + ) |
| 255 | + assert np.allclose(buffer_to_array(embedding, dtype=dtype), raw, atol=1e-03) |
264 | 256 |
|
265 | 257 |
|
266 | 258 | @pytest.fixture( |
|
0 commit comments