Skip to content

Commit 6898f1d

Browse files
Removes dtype from class definitions and uses constants instead
1 parent 41f5693 commit 6898f1d

File tree

6 files changed

+30
-34
lines changed

6 files changed

+30
-34
lines changed

redisvl/extensions/llmcache/schema.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ class CacheEntry(BaseModel):
2626
"""Optional metadata stored on the cache entry"""
2727
filters: Optional[Dict[str, Any]] = Field(default=None)
2828
"""Optional filter data stored on the cache entry for customizing retrieval"""
29-
dtype: str
30-
"""The data type for the prompt vector."""
3129

3230
@root_validator(pre=True)
3331
@classmethod
@@ -43,9 +41,9 @@ def non_empty_metadata(cls, v):
4341
raise TypeError("Metadata must be a dictionary.")
4442
return v
4543

46-
def to_dict(self) -> Dict:
44+
def to_dict(self, dtype: str) -> Dict:
4745
data = self.dict(exclude_none=True)
48-
data["prompt_vector"] = array_to_buffer(self.prompt_vector, self.dtype)
46+
data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype)
4947
if self.metadata is not None:
5048
data["metadata"] = serialize(self.metadata)
5149
if self.filters is not None:

redisvl/extensions/llmcache/semantic.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims
1616
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
1717

18+
VECTOR_FIELD_NAME = "prompt_vector" ###
19+
1820

1921
class SemanticCache(BaseLLMCache):
2022
"""Semantic Cache for Large Language Models."""
@@ -23,7 +25,7 @@ class SemanticCache(BaseLLMCache):
2325
entry_id_field_name: str = "entry_id"
2426
prompt_field_name: str = "prompt"
2527
response_field_name: str = "response"
26-
vector_field_name: str = "prompt_vector"
28+
###vector_field_name: str = "prompt_vector"
2729
inserted_at_field_name: str = "inserted_at"
2830
updated_at_field_name: str = "updated_at"
2931
metadata_field_name: str = "metadata"
@@ -136,9 +138,10 @@ def __init__(
136138

137139
validate_vector_dims(
138140
vectorizer.dims,
139-
self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore
141+
self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims, # type: ignore
140142
)
141143
self._vectorizer = vectorizer
144+
self._dtype = self.index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype # type: ignore[union-attr]
142145

143146
def _modify_schema(
144147
self,
@@ -290,8 +293,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
290293
if not isinstance(prompt, str):
291294
raise TypeError("Prompt must be a string.")
292295

293-
dtype = self.index.schema.fields[self.vector_field_name].attrs.datatype # type: ignore[union-attr]
294-
return self._vectorizer.embed(prompt, dtype=dtype)
296+
return self._vectorizer.embed(prompt, dtype=self._dtype)
295297

296298
async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
297299
"""Converts a text prompt to its vector representation using the
@@ -304,7 +306,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
304306
def _check_vector_dims(self, vector: List[float]):
305307
"""Checks the size of the provided vector and raises an error if it
306308
doesn't match the search index vector dimensions."""
307-
schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore
309+
schema_vector_dims = self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims # type: ignore
308310
validate_vector_dims(len(vector), schema_vector_dims)
309311

310312
def check(
@@ -367,13 +369,13 @@ def check(
367369

368370
query = RangeQuery(
369371
vector=vector,
370-
vector_field_name=self.vector_field_name,
372+
vector_field_name=VECTOR_FIELD_NAME,
371373
return_fields=self.return_fields,
372374
distance_threshold=distance_threshold,
373375
num_results=num_results,
374376
return_score=True,
375377
filter_expression=filter_expression,
376-
dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
378+
dtype=self._dtype,
377379
)
378380

379381
# Search the cache!
@@ -449,7 +451,7 @@ async def acheck(
449451

450452
query = RangeQuery(
451453
vector=vector,
452-
vector_field_name=self.vector_field_name,
454+
vector_field_name=VECTOR_FIELD_NAME,
453455
return_fields=self.return_fields,
454456
distance_threshold=distance_threshold,
455457
num_results=num_results,
@@ -539,13 +541,12 @@ def store(
539541
prompt_vector=vector,
540542
metadata=metadata,
541543
filters=filters,
542-
dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
543544
)
544545

545546
# Load cache entry with TTL
546547
ttl = ttl or self._ttl
547548
keys = self._index.load(
548-
data=[cache_entry.to_dict()],
549+
data=[cache_entry.to_dict(self._dtype)],
549550
ttl=ttl,
550551
id_field=self.entry_id_field_name,
551552
)
@@ -604,13 +605,12 @@ async def astore(
604605
prompt_vector=vector,
605606
metadata=metadata,
606607
filters=filters,
607-
dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
608608
)
609609

610610
# Load cache entry with TTL
611611
ttl = ttl or self._ttl
612612
keys = await aindex.load(
613-
data=[cache_entry.to_dict()],
613+
data=[cache_entry.to_dict(self._dtype)],
614614
ttl=ttl,
615615
id_field=self.entry_id_field_name,
616616
)

redisvl/extensions/router/semantic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
logger = get_logger(__name__)
3030

31+
VECTOR_FIELD_NAME = "vector" ###
32+
3133

3234
class SemanticRouter(BaseModel):
3335
"""Semantic Router for managing and querying route vectors."""
@@ -40,7 +42,7 @@ class SemanticRouter(BaseModel):
4042
"""The vectorizer used to embed route references."""
4143
routing_config: RoutingConfig = Field(default_factory=RoutingConfig)
4244
"""Configuration for routing behavior."""
43-
vector_field_name: str = "vector"
45+
### vector_field_name: str = "vector"
4446

4547
_index: SearchIndex = PrivateAttr()
4648

@@ -171,7 +173,7 @@ def _add_routes(self, routes: List[Route]):
171173
reference_vectors = self.vectorizer.embed_many(
172174
[reference for reference in route.references],
173175
as_buffer=True,
174-
dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
176+
dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
175177
)
176178
# set route references
177179
for i, reference in enumerate(route.references):
@@ -248,7 +250,7 @@ def _classify_route(
248250
vector_field_name="vector",
249251
distance_threshold=distance_threshold,
250252
return_fields=["route_name"],
251-
dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
253+
dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
252254
)
253255

254256
aggregate_request = self._build_aggregate_request(
@@ -301,7 +303,7 @@ def _classify_multi_route(
301303
vector_field_name="vector",
302304
distance_threshold=distance_threshold,
303305
return_fields=["route_name"],
304-
dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
306+
dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
305307
)
306308
aggregate_request = self._build_aggregate_request(
307309
vector_range_query, aggregation_method, max_k

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
from redisvl.utils.utils import validate_vector_dims
1414
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
1515

16+
VECTOR_FIELD_NAME = "vector_field" ###
17+
1618

1719
class SemanticSessionManager(BaseSessionManager):
18-
vector_field_name: str = "vector_field"
20+
###vector_field_name: str = "vector_field"
1921

2022
def __init__(
2123
self,
@@ -201,13 +203,13 @@ def get_relevant(
201203

202204
query = RangeQuery(
203205
vector=self._vectorizer.embed(prompt),
204-
vector_field_name=self.vector_field_name,
206+
vector_field_name=VECTOR_FIELD_NAME,
205207
return_fields=return_fields,
206208
distance_threshold=distance_threshold,
207209
num_results=top_k,
208210
return_score=True,
209211
filter_expression=session_filter,
210-
dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
212+
dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
211213
)
212214
messages = self._index.query(query)
213215

@@ -321,15 +323,15 @@ def add_messages(
321323
content_vector = self._vectorizer.embed(message[self.content_field_name])
322324
validate_vector_dims(
323325
len(content_vector),
324-
self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore
326+
self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims, # type: ignore
325327
)
326328

327329
chat_message = ChatMessage(
328330
role=message[self.role_field_name],
329331
content=message[self.content_field_name],
330332
session_tag=session_tag,
331333
vector_field=content_vector,
332-
dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
334+
dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
333335
)
334336

335337
if self.tool_field_name in message:

redisvl/index/storage.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import uuid
33
from typing import Any, Callable, Dict, Iterable, List, Optional
44

5-
from numpy import frombuffer
65
from pydantic.v1 import BaseModel
76
from redis import Redis
87
from redis.asyncio import Redis as AsyncRedis
@@ -394,7 +393,7 @@ class HashStorage(BaseStorage):
394393
"""Hash data type for the index"""
395394

396395
def _validate(self, obj: Dict[str, Any]):
397-
"""Validate that the given object is a dictionary suitable for storage
396+
"""Validate that the given object is a dictionary, suitable for storage
398397
as a Redis hash.
399398
400399
Args:

tests/unit/test_llmcache_schema.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ def test_valid_cache_entry_creation():
1212
prompt="What is AI?",
1313
response="AI is artificial intelligence.",
1414
prompt_vector=[0.1, 0.2, 0.3],
15-
dtype="float16",
1615
)
1716
assert entry.entry_id == hashify("What is AI?")
1817
assert entry.prompt == "What is AI?"
@@ -26,7 +25,6 @@ def test_cache_entry_with_given_entry_id():
2625
prompt="What is AI?",
2726
response="AI is artificial intelligence.",
2827
prompt_vector=[0.1, 0.2, 0.3],
29-
dtype="float16",
3028
)
3129
assert entry.entry_id == "custom_id"
3230

@@ -38,7 +36,6 @@ def test_cache_entry_with_invalid_metadata():
3836
response="AI is artificial intelligence.",
3937
prompt_vector=[0.1, 0.2, 0.3],
4038
metadata="invalid_metadata",
41-
dtype="float64",
4239
)
4340

4441

@@ -49,9 +46,8 @@ def test_cache_entry_to_dict():
4946
prompt_vector=[0.1, 0.2, 0.3],
5047
metadata={"author": "John"},
5148
filters={"category": "technology"},
52-
dtype="float32",
5349
)
54-
result = entry.to_dict()
50+
result = entry.to_dict(dtype="float32")
5551
assert result["entry_id"] == hashify("What is AI?")
5652
assert result["metadata"] == json.dumps({"author": "John"})
5753
assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3], "float32")
@@ -112,9 +108,8 @@ def test_cache_entry_with_empty_optional_fields():
112108
prompt="What is AI?",
113109
response="AI is artificial intelligence.",
114110
prompt_vector=[0.1, 0.2, 0.3],
115-
dtype="bfloat16",
116111
)
117-
result = entry.to_dict()
112+
result = entry.to_dict(dtype="float32")
118113
assert "metadata" not in result
119114
assert "filters" not in result
120115

0 commit comments

Comments
 (0)