adds semantic cache scoped access and additional functionality #180

Merged: 15 commits, merged on Aug 1, 2024

Changes from all commits
8 changes: 0 additions & 8 deletions redisvl/extensions/llmcache/base.py
@@ -58,11 +58,3 @@ def store(
def hash_input(self, prompt: str):
"""Hashes the input using SHA256."""
return hashify(prompt)

def serialize(self, metadata: Dict[str, Any]) -> str:
"""Serlize the input into a string."""
return json.dumps(metadata)

def deserialize(self, metadata: str) -> Dict[str, Any]:
"""Deserialize the input from a string."""
return json.loads(metadata)
136 changes: 108 additions & 28 deletions redisvl/extensions/llmcache/semantic.py
@@ -1,21 +1,53 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from redis import Redis

from redisvl.extensions.llmcache.base import BaseLLMCache
from redisvl.index import SearchIndex
from redisvl.query import RangeQuery
from redisvl.query.filter import FilterExpression, Tag
from redisvl.redis.utils import array_to_buffer
from redisvl.schema.schema import IndexSchema
from redisvl.schema import IndexSchema
from redisvl.utils.utils import current_timestamp, deserialize, serialize
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer


class SemanticCacheIndexSchema(IndexSchema):

@classmethod
def from_params(cls, name: str, vector_dims: int):

return cls(
index={"name": name, "prefix": name}, # type: ignore
fields=[ # type: ignore
{"name": "prompt", "type": "text"},
{"name": "response", "type": "text"},
{"name": "inserted_at", "type": "numeric"},
{"name": "updated_at", "type": "numeric"},
{"name": "label", "type": "tag"},
{
"name": "prompt_vector",
"type": "vector",
"attrs": {
"dims": vector_dims,
"datatype": "float32",
"distance_metric": "cosine",
"algorithm": "flat",
},
},
],
)
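
A rough sketch of how the new schema factory is used (the index name "llmcache" and the 768-dim width are illustrative assumptions, not taken from the diff):

from redisvl.extensions.llmcache.semantic import SemanticCacheIndexSchema

# 768 matches the embedding width of sentence-transformers/all-mpnet-base-v2
schema = SemanticCacheIndexSchema.from_params("llmcache", 768)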


class SemanticCache(BaseLLMCache):
"""Semantic Cache for Large Language Models."""

entry_id_field_name: str = "_id"
prompt_field_name: str = "prompt"
vector_field_name: str = "prompt_vector"
inserted_at_field_name: str = "inserted_at"
updated_at_field_name: str = "updated_at"
tag_field_name: str = "label"
response_field_name: str = "response"
metadata_field_name: str = "metadata"

@@ -69,27 +101,7 @@ def __init__(
model="sentence-transformers/all-mpnet-base-v2"
)

# build cache index schema
schema = IndexSchema.from_dict({"index": {"name": name, "prefix": prefix}})
# add fields
schema.add_fields(
[
{"name": self.prompt_field_name, "type": "text"},
{"name": self.response_field_name, "type": "text"},
{
"name": self.vector_field_name,
"type": "vector",
"attrs": {
"dims": vectorizer.dims,
"datatype": "float32",
"distance_metric": "cosine",
"algorithm": "flat",
},
},
]
)

# build search index
schema = SemanticCacheIndexSchema.from_params(name, vectorizer.dims)
self._index = SearchIndex(schema=schema)

# handle redis connection
@@ -103,12 +115,12 @@ def __init__(
self.entry_id_field_name,
self.prompt_field_name,
self.response_field_name,
self.tag_field_name,
self.vector_field_name,
self.metadata_field_name,
]
self.set_vectorizer(vectorizer)
self.set_threshold(distance_threshold)

self._index.create(overwrite=False)

@property
@@ -182,6 +194,14 @@ def delete(self) -> None:
index."""
self._index.delete(drop=True)

def drop(self, document_ids: Union[str, List[str]]) -> None:
    """Remove one or more entries from the cache by document ID.

    Args:
        document_ids (Union[str, List[str]]): The document ID or IDs to remove from the cache.
    """
    self._index.drop_keys(document_ids)

Collaborator review comment: So for this, are we expecting users to pass the full redis key? Or just the id portion (without prefix)? I think the terminology we try to use for this throughout the library is id when we are referring to the part without the prefix. So maybe we just use ids instead of document_ids?
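
Whichever name wins out, a minimal usage sketch against the current implementation (the cache name and prompt are hypothetical; per drop_keys, the arguments are the full Redis keys returned by store):

cache = SemanticCache(name="llmcache")
key = cache.store("what is the capital of France?", "Paris")
cache.drop(key)                    # remove a single entry by its key
# cache.drop([key_one, key_two])   # or several at once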

def _refresh_ttl(self, key: str) -> None:
"""Refresh the time-to-live for the specified key."""
if self._ttl:
@@ -195,7 +215,11 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
return self._vectorizer.embed(prompt)

def _search_cache(
self, vector: List[float], num_results: int, return_fields: Optional[List[str]]
self,
vector: List[float],
num_results: int,
return_fields: Optional[List[str]],
tag_filter: Optional[FilterExpression],
) -> List[Dict[str, Any]]:
"""Searches the semantic cache for similar prompt vectors and returns
the specified return fields for each cache hit."""
@@ -217,6 +241,8 @@ def _search_cache(
num_results=num_results,
return_score=True,
)
if tag_filter:
query.set_filter(tag_filter) # type: ignore

# Gather and return the cache hits
cache_hits: List[Dict[str, Any]] = self._index.query(query)
@@ -226,7 +252,7 @@ def store(
self._refresh_ttl(key)
# Check for metadata and deserialize
if self.metadata_field_name in hit:
hit[self.metadata_field_name] = self.deserialize(
hit[self.metadata_field_name] = deserialize(
hit[self.metadata_field_name]
)
return cache_hits
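
For reference, a sketch of the kind of query _search_cache now builds when a tag filter is supplied (the vector, threshold, and tag value below are placeholders):

from redisvl.query import RangeQuery
from redisvl.query.filter import Tag

vector = [0.1] * 768  # placeholder embedding
query = RangeQuery(
    vector=vector,
    vector_field_name="prompt_vector",
    return_fields=["prompt", "response"],
    distance_threshold=0.1,
    num_results=1,
    return_score=True,
)
query.set_filter(Tag("label") == "customer-123")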
@@ -248,6 +274,7 @@ def check(
vector: Optional[List[float]] = None,
num_results: int = 1,
return_fields: Optional[List[str]] = None,
tag_filter: Optional[FilterExpression] = None,
) -> List[Dict[str, Any]]:
"""Checks the semantic cache for results similar to the specified prompt
or vector.
@@ -267,6 +294,8 @@
return_fields (Optional[List[str]], optional): The fields to include
in each returned result. If None, defaults to all available
fields in the cached entry.
tag_filter (Optional[FilterExpression]): the tag filter to filter
results by. Defaults to None, in which case the full cache is searched.

Returns:
List[Dict[str, Any]]: A list of dicts containing the requested
@@ -291,7 +320,7 @@
self._check_vector_dims(vector)

# Check for cache hits by searching the cache
cache_hits = self._search_cache(vector, num_results, return_fields)
cache_hits = self._search_cache(vector, num_results, return_fields, tag_filter)
return cache_hits
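
Continuing the earlier sketch, a hedged example of the new scoped lookup (the tag value "customer-123" is illustrative and assumes entries were stored with that label):

from redisvl.query.filter import Tag

hits = cache.check(
    prompt="what is the capital of France?",
    tag_filter=Tag("label") == "customer-123",
)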

def store(
@@ -300,6 +329,7 @@
response: str,
vector: Optional[List[float]] = None,
metadata: Optional[dict] = None,
tag: Optional[str] = None,
) -> str:
"""Stores the specified key-value pair in the cache along with metadata.

@@ -311,6 +341,8 @@
demand.
metadata (Optional[dict], optional): The optional metadata to cache
alongside the prompt and response. Defaults to None.
tag (Optional[str]): The optional tag to assign to the cache entry.
Defaults to None.

Returns:
str: The Redis key for the entries added to the semantic cache.
@@ -333,19 +365,67 @@
self._check_vector_dims(vector)

# Construct semantic cache payload
now = current_timestamp()
id_field = self.entry_id_field_name
payload = {
id_field: self.hash_input(prompt),
self.prompt_field_name: prompt,
self.response_field_name: response,
self.vector_field_name: array_to_buffer(vector),
self.inserted_at_field_name: now,
self.updated_at_field_name: now,
}
if metadata is not None:
if not isinstance(metadata, dict):
raise TypeError("If specified, cached metadata must be a dictionary.")
# Serialize the metadata dict and add to cache payload
payload[self.metadata_field_name] = self.serialize(metadata)
payload[self.metadata_field_name] = serialize(metadata)
if tag is not None:
payload[self.tag_field_name] = tag

# Load LLMCache entry with TTL
keys = self._index.load(data=[payload], ttl=self._ttl, id_field=id_field)
return keys[0]
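
Correspondingly, a sketch of storing a tagged entry (prompt, response, metadata, and tag values are made up for illustration):

key = cache.store(
    prompt="what is the capital of France?",
    response="Paris",
    metadata={"model_name": "gpt-3.5-turbo"},
    tag="customer-123",
)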

def update(self, key: str, **kwargs) -> None:
"""Update specific fields within an existing cache entry. If no fields
are passed, then only the document TTL is refreshed.

Args:
key (str): the key of the document to update.
kwargs: the field names and values to update on the cached entry.

Raises:
ValueError if an incorrect mapping is provided as a kwarg.
TypeError if metadata is provided and not of type dict.

.. code-block:: python

    key = cache.store('this is a prompt', 'this is a response')
    cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"})
"""
if not kwargs:
self._refresh_ttl(key)
return

for _key, val in kwargs.items():
if _key not in {
self.prompt_field_name,
self.vector_field_name,
self.response_field_name,
self.tag_field_name,
self.metadata_field_name,
}:
raise ValueError(f" {key} is not a valid field within document")

# Check for metadata and deserialize
if _key == self.metadata_field_name:
if isinstance(val, dict):
kwargs[_key] = serialize(val)
else:
raise TypeError(
"If specified, cached metadata must be a dictionary."
)
kwargs.update({self.updated_at_field_name: current_timestamp()})
self._index.client.hset(key, mapping=kwargs) # type: ignore
self._refresh_ttl(key)
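
A short usage sketch of update (field values are illustrative):

key = cache.store("this is a prompt", "this is a response")
cache.update(key, tag="priority", metadata={"hit_count": 1})
cache.update(key)  # with no kwargs, only the TTL is refreshed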
2 changes: 1 addition & 1 deletion redisvl/utils/utils.py
@@ -54,6 +54,6 @@ def serialize(data: Dict[str, Any]) -> str:
return json.dumps(data)


def deserialize(self, data: str) -> Dict[str, Any]:
def deserialize(data: str) -> Dict[str, Any]:
"""Deserialize the input from a string."""
return json.loads(data)
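
With the stray self parameter removed, the two module-level helpers round-trip cleanly; a quick sanity check (the payload contents are arbitrary):

from redisvl.utils.utils import serialize, deserialize

payload = {"hit_count": 1, "model_name": "Llama-2-7b"}
assert deserialize(serialize(payload)) == payload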