diff --git a/nemoguardrails/llm/cache/utils.py b/nemoguardrails/llm/cache/utils.py
index 15a6328a6..5975a801a 100644
--- a/nemoguardrails/llm/cache/utils.py
+++ b/nemoguardrails/llm/cache/utils.py
@@ -35,9 +35,20 @@ class LLMStatsDict(TypedDict):
     completion_tokens: int
 
 
+class LLMMetadataDict(TypedDict):
+    model_name: str
+    provider_name: str
+
+
+class LLMCacheData(TypedDict):
+    stats: Optional[LLMStatsDict]
+    metadata: Optional[LLMMetadataDict]
+
+
 class CacheEntry(TypedDict):
     result: dict
     llm_stats: Optional[LLMStatsDict]
+    llm_metadata: Optional[LLMMetadataDict]
 
 
 def create_normalized_cache_key(
@@ -95,7 +106,7 @@ def create_normalized_cache_key(
 
 
 def restore_llm_stats_from_cache(
-    cached_stats: LLMStatsDict, cache_read_duration: float
+    cached_stats: LLMStatsDict, cache_read_duration_s: float
 ) -> None:
     llm_stats = llm_stats_var.get()
     if llm_stats is None:
@@ -103,19 +114,19 @@
         llm_stats_var.set(llm_stats)
 
     llm_stats.inc("total_calls")
-    llm_stats.inc("total_time", cache_read_duration)
+    llm_stats.inc("total_time", cache_read_duration_s)
     llm_stats.inc("total_tokens", cached_stats.get("total_tokens", 0))
     llm_stats.inc("total_prompt_tokens", cached_stats.get("prompt_tokens", 0))
     llm_stats.inc("total_completion_tokens", cached_stats.get("completion_tokens", 0))
 
     llm_call_info = llm_call_info_var.get()
     if llm_call_info:
-        llm_call_info.duration = cache_read_duration
+        llm_call_info.duration = cache_read_duration_s
         llm_call_info.total_tokens = cached_stats.get("total_tokens", 0)
         llm_call_info.prompt_tokens = cached_stats.get("prompt_tokens", 0)
        llm_call_info.completion_tokens = cached_stats.get("completion_tokens", 0)
         llm_call_info.from_cache = True
-        llm_call_info.started_at = time() - cache_read_duration
+        llm_call_info.started_at = time() - cache_read_duration_s
         llm_call_info.finished_at = time()
 
 
@@ -130,6 +141,25 @@ def extract_llm_stats_for_cache() -> Optional[LLMStatsDict]:
     return None
 
 
+def extract_llm_metadata_for_cache() -> Optional[LLMMetadataDict]:
+    llm_call_info = llm_call_info_var.get()
+    if llm_call_info:
+        return {
+            "model_name": llm_call_info.llm_model_name or "unknown",
+            "provider_name": llm_call_info.llm_provider_name or "unknown",
+        }
+    return None
+
+
+def restore_llm_metadata_from_cache(cached_metadata: LLMMetadataDict) -> None:
+    llm_call_info = llm_call_info_var.get()
+    if llm_call_info:
+        llm_call_info.llm_model_name = cached_metadata.get("model_name", "unknown")
+        llm_call_info.llm_provider_name = cached_metadata.get(
+            "provider_name", "unknown"
+        )
+
+
 def get_from_cache_and_restore_stats(
     cache: "CacheInterface", cache_key: str
 ) -> Optional[dict]:
@@ -137,13 +167,17 @@
     if cached_entry is None:
         return None
 
-    cache_read_start = time()
+    cache_read_start_s = time()
     final_result = cached_entry["result"]
     cached_stats = cached_entry.get("llm_stats")
-    cache_read_duration = time() - cache_read_start
+    cached_metadata = cached_entry.get("llm_metadata")
+    cache_read_duration_s = time() - cache_read_start_s
 
     if cached_stats:
-        restore_llm_stats_from_cache(cached_stats, cache_read_duration)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s)
+
+    if cached_metadata:
+        restore_llm_metadata_from_cache(cached_metadata)
 
     processing_log = processing_log_var.get()
     if processing_log is not None:
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index 7c7b0fdca..464f4cfe7 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -13,16 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import MagicMock
 
 import pytest
 
 from nemoguardrails.context import llm_call_info_var, llm_stats_var
 from nemoguardrails.llm.cache.lfu import LFUCache
 from nemoguardrails.llm.cache.utils import (
+    CacheEntry,
+    LLMMetadataDict,
+    LLMStatsDict,
     create_normalized_cache_key,
+    extract_llm_metadata_for_cache,
     extract_llm_stats_for_cache,
     get_from_cache_and_restore_stats,
+    restore_llm_metadata_from_cache,
     restore_llm_stats_from_cache,
 )
 from nemoguardrails.logging.explain import LLMCallInfo
@@ -31,6 +35,12 @@
 
 
 class TestCacheUtils:
+    @pytest.fixture(autouse=True)
+    def isolated_llm_call_info_var(self):
+        llm_call_info_var.set(None)
+        yield
+        llm_call_info_var.set(None)
+
     def test_create_normalized_cache_key_returns_sha256_hash(self):
         key = create_normalized_cache_key("Hello world")
         assert len(key) == 64
@@ -97,10 +107,10 @@ def test_create_normalized_cache_key_different_for_different_input(
 
     def test_create_normalized_cache_key_invalid_type_raises_error(self):
         with pytest.raises(TypeError, match="Invalid type for prompt: int"):
-            create_normalized_cache_key(123)
+            create_normalized_cache_key(123)  # type: ignore
 
         with pytest.raises(TypeError, match="Invalid type for prompt: dict"):
-            create_normalized_cache_key({"key": "value"})
+            create_normalized_cache_key({"key": "value"})  # type: ignore
 
     def test_create_normalized_cache_key_list_of_dicts(self):
         messages = [
@@ -129,25 +139,24 @@ def test_create_normalized_cache_key_invalid_list_raises_error(self):
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key(["hello", "world"])
+            create_normalized_cache_key(["hello", "world"])  # type: ignore
 
         with pytest.raises(
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key([{"key": "value"}, "test"])
+            create_normalized_cache_key([{"key": "value"}, "test"])  # type: ignore
 
         with pytest.raises(
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key([123, 456])
+            create_normalized_cache_key([123, 456])  # type: ignore
 
     def test_extract_llm_stats_for_cache_with_llm_call_info(self):
-        llm_call_info = LLMCallInfo(task="test_task")
-        llm_call_info.total_tokens = 100
-        llm_call_info.prompt_tokens = 50
-        llm_call_info.completion_tokens = 50
+        llm_call_info = LLMCallInfo(
+            task="test_task", total_tokens=100, prompt_tokens=50, completion_tokens=50
+        )
         llm_call_info_var.set(llm_call_info)
 
         stats = extract_llm_stats_for_cache()
@@ -167,10 +176,12 @@ def test_extract_llm_stats_for_cache_without_llm_call_info(self):
         assert stats is None
 
     def test_extract_llm_stats_for_cache_with_none_values(self):
-        llm_call_info = LLMCallInfo(task="test_task")
-        llm_call_info.total_tokens = None
-        llm_call_info.prompt_tokens = None
-        llm_call_info.completion_tokens = None
+        llm_call_info = LLMCallInfo(
+            task="test_task",
+            total_tokens=None,
+            prompt_tokens=None,
+            completion_tokens=None,
+        )
         llm_call_info_var.set(llm_call_info)
 
         stats = extract_llm_stats_for_cache()
@@ -186,13 +197,13 @@ def test_restore_llm_stats_from_cache_creates_new_llm_stats(self):
         llm_stats_var.set(None)
         llm_call_info_var.set(None)
 
-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
             "completion_tokens": 50,
         }
 
-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.01)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.01)
 
         llm_stats = llm_stats_var.get()
         assert llm_stats is not None
@@ -211,15 +222,16 @@ def test_restore_llm_stats_from_cache_updates_existing_llm_stats(self):
         llm_stats.inc("total_tokens", 200)
         llm_stats_var.set(llm_stats)
 
-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
             "completion_tokens": 50,
         }
 
-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.5)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.5)
 
         llm_stats = llm_stats_var.get()
+        assert llm_stats is not None
         assert llm_stats.get_stat("total_calls") == 6
         assert llm_stats.get_stat("total_time") == 1.5
         assert llm_stats.get_stat("total_tokens") == 300
@@ -231,13 +243,13 @@ def test_restore_llm_stats_from_cache_updates_llm_call_info(self):
         llm_call_info_var.set(llm_call_info)
         llm_stats_var.set(None)
 
-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
             "completion_tokens": 50,
         }
 
-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.02)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.02)
 
         updated_info = llm_call_info_var.get()
         assert updated_info is not None
@@ -273,6 +285,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self):
                 "prompt_tokens": 50,
                 "completion_tokens": 50,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -291,6 +304,7 @@
         assert llm_stats.get_stat("total_tokens") == 100
 
         updated_info = llm_call_info_var.get()
+        assert updated_info is not None
         assert updated_info.from_cache is True
 
         llm_call_info_var.set(None)
@@ -301,6 +315,7 @@ def test_get_from_cache_and_restore_stats_without_llm_stats(self):
         cache_entry = {
             "result": {"allowed": False, "policy_violations": ["policy1"]},
             "llm_stats": None,
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -324,6 +339,7 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
                 "prompt_tokens": 60,
                 "completion_tokens": 20,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -332,7 +348,7 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
         llm_stats_var.set(None)
 
         processing_log = []
-        processing_log_var.set(processing_log)
+        processing_log_var.set(processing_log)  # type: ignore
 
         result = get_from_cache_and_restore_stats(cache, "test_key")
 
@@ -340,6 +356,7 @@
         assert result == {"allowed": True, "policy_violations": []}
 
         retrieved_log = processing_log_var.get()
+        assert retrieved_log is not None
         assert len(retrieved_log) == 1
         assert retrieved_log[0]["type"] == "llm_call_info"
         assert "timestamp" in retrieved_log[0]
@@ -359,6 +376,7 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):
                 "prompt_tokens": 30,
                 "completion_tokens": 20,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -374,3 +392,91 @@
 
         llm_call_info_var.set(None)
         llm_stats_var.set(None)
+
+    def test_extract_llm_metadata_for_cache_with_model_info(self):
+        llm_call_info = LLMCallInfo(
+            task="test_task", llm_model_name="gpt-4", llm_provider_name="openai"
+        )
+        llm_call_info_var.set(llm_call_info)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is not None
+        assert metadata["model_name"] == "gpt-4"
+        assert metadata["provider_name"] == "openai"
+
+        llm_call_info_var.set(None)
+
+    def test_extract_llm_metadata_for_cache_without_llm_call_info(self):
+        llm_call_info_var.set(None)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is None
+
+    def test_extract_llm_metadata_for_cache_with_defaults(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is not None
+        assert metadata["model_name"] == "unknown"
+        assert metadata["provider_name"] == "unknown"
+
+        llm_call_info_var.set(None)
+
+    def test_restore_llm_metadata_from_cache(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+
+        cached_metadata: LLMMetadataDict = {
+            "model_name": "nvidia/llama-3.1-nemoguard-8b-content-safety",
+            "provider_name": "nim",
+        }
+
+        restore_llm_metadata_from_cache(cached_metadata)
+
+        updated_info = llm_call_info_var.get()
+        assert updated_info is not None
+        assert (
+            updated_info.llm_model_name
+            == "nvidia/llama-3.1-nemoguard-8b-content-safety"
+        )
+        assert updated_info.llm_provider_name == "nim"
+
+        llm_call_info_var.set(None)
+
+    def test_get_from_cache_and_restore_stats_with_metadata(self):
+        cache = LFUCache(maxsize=10)
+        cache_entry: CacheEntry = {
+            "result": {"allowed": True, "policy_violations": []},
+            "llm_stats": {
+                "total_tokens": 100,
+                "prompt_tokens": 50,
+                "completion_tokens": 50,
+            },
+            "llm_metadata": {
+                "model_name": "gpt-4-turbo",
+                "provider_name": "openai",
+            },
+        }
+        cache.put("test_key", cache_entry)
+
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+        llm_stats_var.set(None)
+
+        result = get_from_cache_and_restore_stats(cache, "test_key")
+
+        assert result is not None
+        assert result == {"allowed": True, "policy_violations": []}
+
+        updated_info = llm_call_info_var.get()
+        assert updated_info is not None
+        assert updated_info.from_cache is True
+        assert updated_info.llm_model_name == "gpt-4-turbo"
+        assert updated_info.llm_provider_name == "openai"
+
+        llm_call_info_var.set(None)
+        llm_stats_var.set(None)
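
Usage sketch (reviewer note, not part of the patch): the round trip the new `llm_metadata` field enables — capture model/provider info from the active `LLMCallInfo` when a cache entry is written, then restore it on a hit so cached calls stay attributed to the correct model. The task label, message content, and cache size below are illustrative only; the calls themselves mirror the tests above.

```python
from nemoguardrails.context import llm_call_info_var, llm_stats_var
from nemoguardrails.llm.cache.lfu import LFUCache
from nemoguardrails.llm.cache.utils import (
    CacheEntry,
    create_normalized_cache_key,
    extract_llm_metadata_for_cache,
    extract_llm_stats_for_cache,
    get_from_cache_and_restore_stats,
)
from nemoguardrails.logging.explain import LLMCallInfo

cache = LFUCache(maxsize=100)
key = create_normalized_cache_key([{"role": "user", "content": "Is this safe?"}])

# After a real LLM call: the active LLMCallInfo carries token stats plus the
# model/provider names that the new extractor captures alongside them.
llm_call_info_var.set(
    LLMCallInfo(
        task="example_task",  # illustrative task label
        total_tokens=100,
        prompt_tokens=50,
        completion_tokens=50,
        llm_model_name="gpt-4",
        llm_provider_name="openai",
    )
)
entry: CacheEntry = {
    "result": {"allowed": True, "policy_violations": []},
    "llm_stats": extract_llm_stats_for_cache(),
    "llm_metadata": extract_llm_metadata_for_cache(),
}
cache.put(key, entry)

# On a later identical request: the hit restores stats and metadata into the
# fresh LLMCallInfo and marks it as served from cache.
llm_call_info_var.set(LLMCallInfo(task="example_task"))
llm_stats_var.set(None)
result = get_from_cache_and_restore_stats(cache, key)

assert result == {"allowed": True, "policy_violations": []}
info = llm_call_info_var.get()
assert info is not None and info.from_cache is True
assert info.llm_model_name == "gpt-4" and info.llm_provider_name == "openai"
```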