1515from redisvl .utils .utils import current_timestamp , serialize , validate_vector_dims
1616from redisvl .utils .vectorize import BaseVectorizer , HFTextVectorizer
1717
# Name of the vector field in the cache's index schema. It is the key used to
# look up the field's dims/datatype in ``schema.fields`` and the
# ``vector_field_name`` passed to the range queries built by the cache.
VECTOR_FIELD_NAME = "prompt_vector"
19+
1820
1921class SemanticCache (BaseLLMCache ):
2022 """Semantic Cache for Large Language Models."""
@@ -23,7 +25,7 @@ class SemanticCache(BaseLLMCache):
2325 entry_id_field_name : str = "entry_id"
2426 prompt_field_name : str = "prompt"
2527 response_field_name : str = "response"
26- vector_field_name : str = "prompt_vector"
28+ ### vector_field_name: str = "prompt_vector"
2729 inserted_at_field_name : str = "inserted_at"
2830 updated_at_field_name : str = "updated_at"
2931 metadata_field_name : str = "metadata"
@@ -136,9 +138,10 @@ def __init__(
136138
137139 validate_vector_dims (
138140 vectorizer .dims ,
139- self ._index .schema .fields [self . vector_field_name ].attrs .dims , # type: ignore
141+ self ._index .schema .fields [VECTOR_FIELD_NAME ].attrs .dims , # type: ignore
140142 )
141143 self ._vectorizer = vectorizer
144+ self ._dtype = self .index .schema .fields [VECTOR_FIELD_NAME ].attrs .datatype # type: ignore[union-attr]
142145
143146 def _modify_schema (
144147 self ,
@@ -290,8 +293,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
290293 if not isinstance (prompt , str ):
291294 raise TypeError ("Prompt must be a string." )
292295
293- dtype = self .index .schema .fields [self .vector_field_name ].attrs .datatype # type: ignore[union-attr]
294- return self ._vectorizer .embed (prompt , dtype = dtype )
296+ return self ._vectorizer .embed (prompt , dtype = self ._dtype )
295297
296298 async def _avectorize_prompt (self , prompt : Optional [str ]) -> List [float ]:
297299 """Converts a text prompt to its vector representation using the
@@ -304,7 +306,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
def _check_vector_dims(self, vector: List[float]) -> None:
    """Check that a vector's length matches the search index vector dims.

    The expected dimensionality is read from the ``VECTOR_FIELD_NAME`` field
    of the index schema and compared against ``len(vector)`` by
    ``validate_vector_dims``, which is responsible for raising an error on a
    mismatch.

    Args:
        vector: The embedding to validate.
    """
    # The schema field's attrs type is a union, hence the type-ignore.
    schema_vector_dims = self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims  # type: ignore
    validate_vector_dims(len(vector), schema_vector_dims)
309311
310312 def check (
@@ -367,13 +369,13 @@ def check(
367369
368370 query = RangeQuery (
369371 vector = vector ,
370- vector_field_name = self . vector_field_name ,
372+ vector_field_name = VECTOR_FIELD_NAME ,
371373 return_fields = self .return_fields ,
372374 distance_threshold = distance_threshold ,
373375 num_results = num_results ,
374376 return_score = True ,
375377 filter_expression = filter_expression ,
376- dtype = self .index . schema . fields [ self . vector_field_name ]. attrs . datatype , # type: ignore[union-attr]
378+ dtype = self ._dtype ,
377379 )
378380
379381 # Search the cache!
@@ -449,7 +451,7 @@ async def acheck(
449451
450452 query = RangeQuery (
451453 vector = vector ,
452- vector_field_name = self . vector_field_name ,
454+ vector_field_name = VECTOR_FIELD_NAME ,
453455 return_fields = self .return_fields ,
454456 distance_threshold = distance_threshold ,
455457 num_results = num_results ,
@@ -539,13 +541,12 @@ def store(
539541 prompt_vector = vector ,
540542 metadata = metadata ,
541543 filters = filters ,
542- dtype = self .index .schema .fields [self .vector_field_name ].attrs .datatype , # type: ignore[union-attr]
543544 )
544545
545546 # Load cache entry with TTL
546547 ttl = ttl or self ._ttl
547548 keys = self ._index .load (
548- data = [cache_entry .to_dict ()],
549+ data = [cache_entry .to_dict (self . _dtype )],
549550 ttl = ttl ,
550551 id_field = self .entry_id_field_name ,
551552 )
@@ -604,13 +605,12 @@ async def astore(
604605 prompt_vector = vector ,
605606 metadata = metadata ,
606607 filters = filters ,
607- dtype = self .index .schema .fields [self .vector_field_name ].attrs .datatype , # type: ignore[union-attr]
608608 )
609609
610610 # Load cache entry with TTL
611611 ttl = ttl or self ._ttl
612612 keys = await aindex .load (
613- data = [cache_entry .to_dict ()],
613+ data = [cache_entry .to_dict (self . _dtype )],
614614 ttl = ttl ,
615615 id_field = self .entry_id_field_name ,
616616 )
0 commit comments