48 changes: 41 additions & 7 deletions nemoguardrails/llm/cache/utils.py
@@ -35,9 +35,20 @@ class LLMStatsDict(TypedDict):
completion_tokens: int


class LLMMetadataDict(TypedDict):
model_name: str
provider_name: str


class LLMCacheData(TypedDict):
stats: Optional[LLMStatsDict]
metadata: Optional[LLMMetadataDict]


class CacheEntry(TypedDict):
result: dict
llm_stats: Optional[LLMStatsDict]
llm_metadata: Optional[LLMMetadataDict]


def create_normalized_cache_key(
@@ -95,27 +106,27 @@ def create_normalized_cache_key(


def restore_llm_stats_from_cache(
cached_stats: LLMStatsDict, cache_read_duration: float
cached_stats: LLMStatsDict, cache_read_duration_s: float
) -> None:
llm_stats = llm_stats_var.get()
if llm_stats is None:
llm_stats = LLMStats()
llm_stats_var.set(llm_stats)

llm_stats.inc("total_calls")
llm_stats.inc("total_time", cache_read_duration)
llm_stats.inc("total_time", cache_read_duration_s)
llm_stats.inc("total_tokens", cached_stats.get("total_tokens", 0))
llm_stats.inc("total_prompt_tokens", cached_stats.get("prompt_tokens", 0))
llm_stats.inc("total_completion_tokens", cached_stats.get("completion_tokens", 0))

llm_call_info = llm_call_info_var.get()
if llm_call_info:
llm_call_info.duration = cache_read_duration
llm_call_info.duration = cache_read_duration_s
llm_call_info.total_tokens = cached_stats.get("total_tokens", 0)
llm_call_info.prompt_tokens = cached_stats.get("prompt_tokens", 0)
llm_call_info.completion_tokens = cached_stats.get("completion_tokens", 0)
llm_call_info.from_cache = True
llm_call_info.started_at = time() - cache_read_duration
llm_call_info.started_at = time() - cache_read_duration_s
llm_call_info.finished_at = time()


@@ -130,20 +141,43 @@ def extract_llm_stats_for_cache() -> Optional[LLMStatsDict]:
return None


def extract_llm_metadata_for_cache() -> Optional[LLMMetadataDict]:
llm_call_info = llm_call_info_var.get()
if llm_call_info:
return {
"model_name": llm_call_info.llm_model_name or "unknown",
"provider_name": llm_call_info.llm_provider_name or "unknown",
}
return None


def restore_llm_metadata_from_cache(cached_metadata: LLMMetadataDict) -> None:
llm_call_info = llm_call_info_var.get()
if llm_call_info:
llm_call_info.llm_model_name = cached_metadata.get("model_name", "unknown")
llm_call_info.llm_provider_name = cached_metadata.get(
"provider_name", "unknown"
)


def get_from_cache_and_restore_stats(
cache: "CacheInterface", cache_key: str
) -> Optional[dict]:
cached_entry = cache.get(cache_key)
if cached_entry is None:
return None

cache_read_start = time()
cache_read_start_s = time()
final_result = cached_entry["result"]
cached_stats = cached_entry.get("llm_stats")
cache_read_duration = time() - cache_read_start
cached_metadata = cached_entry.get("llm_metadata")
cache_read_duration_s = time() - cache_read_start_s

if cached_stats:
restore_llm_stats_from_cache(cached_stats, cache_read_duration)
restore_llm_stats_from_cache(cached_stats, cache_read_duration_s)

if cached_metadata:
restore_llm_metadata_from_cache(cached_metadata)

processing_log = processing_log_var.get()
if processing_log is not None:
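
For reference, a minimal sketch of how the updated helpers fit together. The read path (get_from_cache_and_restore_stats) is what the diff above changes; the write-side wiring shown here is an assumption based on the helper signatures, not code from this PR, and the key/result values are made up.

# Hypothetical round trip through the cache helpers; the write-side wiring
# is an assumption, not code from this PR.
from nemoguardrails.llm.cache.lfu import LFUCache
from nemoguardrails.llm.cache.utils import (
    CacheEntry,
    create_normalized_cache_key,
    extract_llm_metadata_for_cache,
    extract_llm_stats_for_cache,
    get_from_cache_and_restore_stats,
)

cache = LFUCache(maxsize=100)
key = create_normalized_cache_key("Is this prompt safe?")

# Write path: after an LLM call, capture token stats and model/provider
# metadata from the current LLMCallInfo context variable and store them
# alongside the result. Both extractors return None if no call info is set.
entry: CacheEntry = {
    "result": {"allowed": True, "policy_violations": []},
    "llm_stats": extract_llm_stats_for_cache(),
    "llm_metadata": extract_llm_metadata_for_cache(),
}
cache.put(key, entry)

# Read path: on a hit, the cached token stats (plus the cache-read duration)
# are folded into LLMStats, and the cached model/provider names are restored
# onto the current LLMCallInfo before the cached result is returned.
result = get_from_cache_and_restore_stats(cache, key)
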
148 changes: 127 additions & 21 deletions tests/test_cache_utils.py
@@ -13,16 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

import pytest

from nemoguardrails.context import llm_call_info_var, llm_stats_var
from nemoguardrails.llm.cache.lfu import LFUCache
from nemoguardrails.llm.cache.utils import (
CacheEntry,
LLMMetadataDict,
LLMStatsDict,
create_normalized_cache_key,
extract_llm_metadata_for_cache,
extract_llm_stats_for_cache,
get_from_cache_and_restore_stats,
restore_llm_metadata_from_cache,
restore_llm_stats_from_cache,
)
from nemoguardrails.logging.explain import LLMCallInfo
@@ -31,6 +35,12 @@


class TestCacheUtils:
@pytest.fixture(autouse=True)
def isolated_llm_call_info_var(self):
llm_call_info_var.set(None)
yield
llm_call_info_var.set(None)

def test_create_normalized_cache_key_returns_sha256_hash(self):
key = create_normalized_cache_key("Hello world")
assert len(key) == 64
@@ -97,10 +107,10 @@ def test_create_normalized_cache_key_different_for_different_input(

def test_create_normalized_cache_key_invalid_type_raises_error(self):
with pytest.raises(TypeError, match="Invalid type for prompt: int"):
create_normalized_cache_key(123)
create_normalized_cache_key(123) # type: ignore

with pytest.raises(TypeError, match="Invalid type for prompt: dict"):
create_normalized_cache_key({"key": "value"})
create_normalized_cache_key({"key": "value"}) # type: ignore

def test_create_normalized_cache_key_list_of_dicts(self):
messages = [
@@ -129,25 +139,24 @@ def test_create_normalized_cache_key_invalid_list_raises_error(self):
TypeError,
match="All elements in prompt list must be dictionaries",
):
create_normalized_cache_key(["hello", "world"])
create_normalized_cache_key(["hello", "world"]) # type: ignore

with pytest.raises(
TypeError,
match="All elements in prompt list must be dictionaries",
):
create_normalized_cache_key([{"key": "value"}, "test"])
create_normalized_cache_key([{"key": "value"}, "test"]) # type: ignore

with pytest.raises(
TypeError,
match="All elements in prompt list must be dictionaries",
):
create_normalized_cache_key([123, 456])
create_normalized_cache_key([123, 456]) # type: ignore

def test_extract_llm_stats_for_cache_with_llm_call_info(self):
llm_call_info = LLMCallInfo(task="test_task")
llm_call_info.total_tokens = 100
llm_call_info.prompt_tokens = 50
llm_call_info.completion_tokens = 50
llm_call_info = LLMCallInfo(
task="test_task", total_tokens=100, prompt_tokens=50, completion_tokens=50
)
llm_call_info_var.set(llm_call_info)

stats = extract_llm_stats_for_cache()
@@ -167,10 +176,12 @@ def test_extract_llm_stats_for_cache_without_llm_call_info(self):
assert stats is None

def test_extract_llm_stats_for_cache_with_none_values(self):
llm_call_info = LLMCallInfo(task="test_task")
llm_call_info.total_tokens = None
llm_call_info.prompt_tokens = None
llm_call_info.completion_tokens = None
llm_call_info = LLMCallInfo(
task="test_task",
total_tokens=None,
prompt_tokens=None,
completion_tokens=None,
)
llm_call_info_var.set(llm_call_info)

stats = extract_llm_stats_for_cache()
@@ -186,13 +197,13 @@ def test_restore_llm_stats_from_cache_creates_new_llm_stats(self):
llm_stats_var.set(None)
llm_call_info_var.set(None)

cached_stats = {
cached_stats: LLMStatsDict = {
"total_tokens": 100,
"prompt_tokens": 50,
"completion_tokens": 50,
}

restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.01)
restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.01)

llm_stats = llm_stats_var.get()
assert llm_stats is not None
@@ -211,15 +222,16 @@ def test_restore_llm_stats_from_cache_updates_existing_llm_stats(self):
llm_stats.inc("total_tokens", 200)
llm_stats_var.set(llm_stats)

cached_stats = {
cached_stats: LLMStatsDict = {
"total_tokens": 100,
"prompt_tokens": 50,
"completion_tokens": 50,
}

restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.5)
restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.5)

llm_stats = llm_stats_var.get()
assert llm_stats is not None
assert llm_stats.get_stat("total_calls") == 6
assert llm_stats.get_stat("total_time") == 1.5
assert llm_stats.get_stat("total_tokens") == 300
@@ -231,13 +243,13 @@ def test_restore_llm_stats_from_cache_updates_llm_call_info(self):
llm_call_info_var.set(llm_call_info)
llm_stats_var.set(None)

cached_stats = {
cached_stats: LLMStatsDict = {
"total_tokens": 100,
"prompt_tokens": 50,
"completion_tokens": 50,
}

restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.02)
restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.02)

updated_info = llm_call_info_var.get()
assert updated_info is not None
@@ -273,6 +285,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self):
"prompt_tokens": 50,
"completion_tokens": 50,
},
"llm_metadata": None,
}
cache.put("test_key", cache_entry)

@@ -291,6 +304,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self):
assert llm_stats.get_stat("total_tokens") == 100

updated_info = llm_call_info_var.get()
assert updated_info is not None
assert updated_info.from_cache is True

llm_call_info_var.set(None)
@@ -301,6 +315,7 @@ def test_get_from_cache_and_restore_stats_without_llm_stats(self):
cache_entry = {
"result": {"allowed": False, "policy_violations": ["policy1"]},
"llm_stats": None,
"llm_metadata": None,
}
cache.put("test_key", cache_entry)

@@ -324,6 +339,7 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
"prompt_tokens": 60,
"completion_tokens": 20,
},
"llm_metadata": None,
}
cache.put("test_key", cache_entry)

@@ -332,14 +348,15 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
llm_stats_var.set(None)

processing_log = []
processing_log_var.set(processing_log)
processing_log_var.set(processing_log) # type: ignore

result = get_from_cache_and_restore_stats(cache, "test_key")

assert result is not None
assert result == {"allowed": True, "policy_violations": []}

retrieved_log = processing_log_var.get()
assert retrieved_log is not None
assert len(retrieved_log) == 1
assert retrieved_log[0]["type"] == "llm_call_info"
assert "timestamp" in retrieved_log[0]
@@ -359,6 +376,7 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):
"prompt_tokens": 30,
"completion_tokens": 20,
},
"llm_metadata": None,
}
cache.put("test_key", cache_entry)

@@ -374,3 +392,91 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):

llm_call_info_var.set(None)
llm_stats_var.set(None)

def test_extract_llm_metadata_for_cache_with_model_info(self):
llm_call_info = LLMCallInfo(
task="test_task", llm_model_name="gpt-4", llm_provider_name="openai"
)
llm_call_info_var.set(llm_call_info)

metadata = extract_llm_metadata_for_cache()

assert metadata is not None
assert metadata["model_name"] == "gpt-4"
assert metadata["provider_name"] == "openai"

llm_call_info_var.set(None)

def test_extract_llm_metadata_for_cache_without_llm_call_info(self):
llm_call_info_var.set(None)

metadata = extract_llm_metadata_for_cache()

assert metadata is None

def test_extract_llm_metadata_for_cache_with_defaults(self):
llm_call_info = LLMCallInfo(task="test_task")
llm_call_info_var.set(llm_call_info)

metadata = extract_llm_metadata_for_cache()

assert metadata is not None
assert metadata["model_name"] == "unknown"
assert metadata["provider_name"] == "unknown"

llm_call_info_var.set(None)

def test_restore_llm_metadata_from_cache(self):
llm_call_info = LLMCallInfo(task="test_task")
llm_call_info_var.set(llm_call_info)

cached_metadata: LLMMetadataDict = {
"model_name": "nvidia/llama-3.1-nemoguard-8b-content-safety",
"provider_name": "nim",
}

restore_llm_metadata_from_cache(cached_metadata)

updated_info = llm_call_info_var.get()
assert updated_info is not None
assert (
updated_info.llm_model_name
== "nvidia/llama-3.1-nemoguard-8b-content-safety"
)
assert updated_info.llm_provider_name == "nim"

llm_call_info_var.set(None)

def test_get_from_cache_and_restore_stats_with_metadata(self):
cache = LFUCache(maxsize=10)
cache_entry: CacheEntry = {
"result": {"allowed": True, "policy_violations": []},
"llm_stats": {
"total_tokens": 100,
"prompt_tokens": 50,
"completion_tokens": 50,
},
"llm_metadata": {
"model_name": "gpt-4-turbo",
"provider_name": "openai",
},
}
cache.put("test_key", cache_entry)

llm_call_info = LLMCallInfo(task="test_task")
llm_call_info_var.set(llm_call_info)
llm_stats_var.set(None)

result = get_from_cache_and_restore_stats(cache, "test_key")

assert result is not None
assert result == {"allowed": True, "policy_violations": []}

updated_info = llm_call_info_var.get()
assert updated_info is not None
assert updated_info.from_cache is True
assert updated_info.llm_model_name == "gpt-4-turbo"
assert updated_info.llm_provider_name == "openai"

llm_call_info_var.set(None)
llm_stats_var.set(None)