
Commit b05cac4

feat(cache): add LLM metadata caching for model and provider information
Extends the cache system to store and restore LLM metadata (model name and provider name) alongside cache entries. This allows cached results to maintain provenance information about which model and provider generated the original response.

- Added LLMMetadataDict and LLMCacheData TypedDict definitions for type safety
- Extended CacheEntry to include an optional llm_metadata field
- Implemented extract_llm_metadata_for_cache() to capture model and provider info from context
- Implemented restore_llm_metadata_from_cache() to restore metadata when retrieving cached results
- Updated get_from_cache_and_restore_stats() to handle metadata extraction and restoration
- Added comprehensive test coverage for metadata caching functionality
1 parent 053cd1c commit b05cac4
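
For orientation, a cache entry now carries provenance next to the result and token stats. The field names below come from the CacheEntry and LLMMetadataDict definitions in the diff; the concrete values are illustrative only.

```python
# Illustrative sketch only: field names match the TypedDicts added in this
# commit, the values are made up for the example.
entry = {
    "result": {"allowed": True, "policy_violations": []},
    "llm_stats": {
        "total_tokens": 100,
        "prompt_tokens": 50,
        "completion_tokens": 50,
    },
    "llm_metadata": {
        "model_name": "gpt-4",      # which model produced the cached result
        "provider_name": "openai",  # which provider served it
    },
}
```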

File tree

nemoguardrails/llm/cache/utils.py
tests/test_cache_utils.py

2 files changed: +143 −10 lines changed

nemoguardrails/llm/cache/utils.py

Lines changed: 34 additions & 0 deletions
@@ -35,9 +35,20 @@ class LLMStatsDict(TypedDict):
     completion_tokens: int


+class LLMMetadataDict(TypedDict):
+    model_name: str
+    provider_name: str
+
+
+class LLMCacheData(TypedDict):
+    stats: Optional[LLMStatsDict]
+    metadata: Optional[LLMMetadataDict]
+
+
 class CacheEntry(TypedDict):
     result: dict
     llm_stats: Optional[LLMStatsDict]
+    llm_metadata: Optional[LLMMetadataDict]


 def create_normalized_cache_key(
@@ -130,6 +141,25 @@ def extract_llm_stats_for_cache() -> Optional[LLMStatsDict]:
     return None


+def extract_llm_metadata_for_cache() -> Optional[LLMMetadataDict]:
+    llm_call_info = llm_call_info_var.get()
+    if llm_call_info:
+        return {
+            "model_name": llm_call_info.llm_model_name or "unknown",
+            "provider_name": llm_call_info.llm_provider_name or "unknown",
+        }
+    return None
+
+
+def restore_llm_metadata_from_cache(cached_metadata: LLMMetadataDict) -> None:
+    llm_call_info = llm_call_info_var.get()
+    if llm_call_info:
+        llm_call_info.llm_model_name = cached_metadata.get("model_name", "unknown")
+        llm_call_info.llm_provider_name = cached_metadata.get(
+            "provider_name", "unknown"
+        )
+
+
 def get_from_cache_and_restore_stats(
     cache: "CacheInterface", cache_key: str
 ) -> Optional[dict]:
@@ -140,11 +170,15 @@ def get_from_cache_and_restore_stats(
     cache_read_start = time()
     final_result = cached_entry["result"]
     cached_stats = cached_entry.get("llm_stats")
+    cached_metadata = cached_entry.get("llm_metadata")
     cache_read_duration = time() - cache_read_start

     if cached_stats:
         restore_llm_stats_from_cache(cached_stats, cache_read_duration)

+    if cached_metadata:
+        restore_llm_metadata_from_cache(cached_metadata)
+
     processing_log = processing_log_var.get()
     if processing_log is not None:
         llm_call_info = llm_call_info_var.get()
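
The diff above adds the metadata helpers and wires the restore step into get_from_cache_and_restore_stats(); the matching write path is not part of this commit. Based purely on the helpers it adds, a caller populating the cache would presumably look something like the sketch below, where store_result() and its arguments are hypothetical:

```python
# Hedged sketch of a cache write using the helpers added above.
# create_normalized_cache_key, extract_llm_stats_for_cache and
# extract_llm_metadata_for_cache are the utilities from this module;
# store_result() itself is a hypothetical call site.
from nemoguardrails.llm.cache.utils import (
    CacheEntry,
    create_normalized_cache_key,
    extract_llm_metadata_for_cache,
    extract_llm_stats_for_cache,
)


def store_result(cache, prompt, result: dict) -> None:
    entry: CacheEntry = {
        "result": result,
        "llm_stats": extract_llm_stats_for_cache(),        # token/timing stats, if any
        "llm_metadata": extract_llm_metadata_for_cache(),  # model/provider provenance
    }
    cache.put(create_normalized_cache_key(prompt), entry)
```

On a later hit, get_from_cache_and_restore_stats() restores both the token stats and the model/provider names onto the active LLMCallInfo, which is what the new tests below exercise.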

tests/test_cache_utils.py

Lines changed: 109 additions & 10 deletions
@@ -13,16 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from unittest.mock import MagicMock

 import pytest

 from nemoguardrails.context import llm_call_info_var, llm_stats_var
 from nemoguardrails.llm.cache.lfu import LFUCache
 from nemoguardrails.llm.cache.utils import (
+    CacheEntry,
+    LLMMetadataDict,
+    LLMStatsDict,
     create_normalized_cache_key,
+    extract_llm_metadata_for_cache,
     extract_llm_stats_for_cache,
     get_from_cache_and_restore_stats,
+    restore_llm_metadata_from_cache,
     restore_llm_stats_from_cache,
 )
 from nemoguardrails.logging.explain import LLMCallInfo
@@ -97,10 +101,10 @@ def test_create_normalized_cache_key_different_for_different_input(

     def test_create_normalized_cache_key_invalid_type_raises_error(self):
         with pytest.raises(TypeError, match="Invalid type for prompt: int"):
-            create_normalized_cache_key(123)
+            create_normalized_cache_key(123)  # type: ignore

         with pytest.raises(TypeError, match="Invalid type for prompt: dict"):
-            create_normalized_cache_key({"key": "value"})
+            create_normalized_cache_key({"key": "value"})  # type: ignore

     def test_create_normalized_cache_key_list_of_dicts(self):
         messages = [
@@ -129,19 +133,19 @@ def test_create_normalized_cache_key_invalid_list_raises_error(self):
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key(["hello", "world"])
+            create_normalized_cache_key(["hello", "world"])  # type: ignore

         with pytest.raises(
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key([{"key": "value"}, "test"])
+            create_normalized_cache_key([{"key": "value"}, "test"])  # type: ignore

         with pytest.raises(
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key([123, 456])
+            create_normalized_cache_key([123, 456])  # type: ignore

     def test_extract_llm_stats_for_cache_with_llm_call_info(self):
         llm_call_info = LLMCallInfo(task="test_task")
@@ -186,7 +190,7 @@ def test_restore_llm_stats_from_cache_creates_new_llm_stats(self):
         llm_stats_var.set(None)
         llm_call_info_var.set(None)

-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
             "completion_tokens": 50,
@@ -211,7 +215,7 @@ def test_restore_llm_stats_from_cache_updates_existing_llm_stats(self):
         llm_stats.inc("total_tokens", 200)
         llm_stats_var.set(llm_stats)

-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
             "completion_tokens": 50,
@@ -220,6 +224,7 @@ def test_restore_llm_stats_from_cache_updates_existing_llm_stats(self):
         restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.5)

         llm_stats = llm_stats_var.get()
+        assert llm_stats is not None
         assert llm_stats.get_stat("total_calls") == 6
         assert llm_stats.get_stat("total_time") == 1.5
         assert llm_stats.get_stat("total_tokens") == 300
@@ -231,7 +236,7 @@ def test_restore_llm_stats_from_cache_updates_llm_call_info(self):
         llm_call_info_var.set(llm_call_info)
         llm_stats_var.set(None)

-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
             "completion_tokens": 50,
@@ -273,6 +278,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self):
                 "prompt_tokens": 50,
                 "completion_tokens": 50,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)

@@ -291,6 +297,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self):
         assert llm_stats.get_stat("total_tokens") == 100

         updated_info = llm_call_info_var.get()
+        assert updated_info is not None
         assert updated_info.from_cache is True

         llm_call_info_var.set(None)
@@ -301,6 +308,7 @@ def test_get_from_cache_and_restore_stats_without_llm_stats(self):
         cache_entry = {
             "result": {"allowed": False, "policy_violations": ["policy1"]},
             "llm_stats": None,
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)

@@ -324,6 +332,7 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
                 "prompt_tokens": 60,
                 "completion_tokens": 20,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)

@@ -332,14 +341,15 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
         llm_stats_var.set(None)

         processing_log = []
-        processing_log_var.set(processing_log)
+        processing_log_var.set(processing_log)  # type: ignore

         result = get_from_cache_and_restore_stats(cache, "test_key")

         assert result is not None
         assert result == {"allowed": True, "policy_violations": []}

         retrieved_log = processing_log_var.get()
+        assert retrieved_log is not None
         assert len(retrieved_log) == 1
         assert retrieved_log[0]["type"] == "llm_call_info"
         assert "timestamp" in retrieved_log[0]
@@ -359,6 +369,7 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):
                 "prompt_tokens": 30,
                 "completion_tokens": 20,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)

@@ -374,3 +385,91 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):

         llm_call_info_var.set(None)
         llm_stats_var.set(None)
+
+    def test_extract_llm_metadata_for_cache_with_model_info(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info.llm_model_name = "gpt-4"
+        llm_call_info.llm_provider_name = "openai"
+        llm_call_info_var.set(llm_call_info)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is not None
+        assert metadata["model_name"] == "gpt-4"
+        assert metadata["provider_name"] == "openai"
+
+        llm_call_info_var.set(None)
+
+    def test_extract_llm_metadata_for_cache_without_llm_call_info(self):
+        llm_call_info_var.set(None)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is None
+
+    def test_extract_llm_metadata_for_cache_with_defaults(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is not None
+        assert metadata["model_name"] == "unknown"
+        assert metadata["provider_name"] == "unknown"
+
+        llm_call_info_var.set(None)
+
+    def test_restore_llm_metadata_from_cache(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+
+        cached_metadata: LLMMetadataDict = {
+            "model_name": "nvidia/llama-3.1-nemoguard-8b-content-safety",
+            "provider_name": "nim",
+        }
+
+        restore_llm_metadata_from_cache(cached_metadata)
+
+        updated_info = llm_call_info_var.get()
+        assert updated_info is not None
+        assert (
+            updated_info.llm_model_name
+            == "nvidia/llama-3.1-nemoguard-8b-content-safety"
+        )
+        assert updated_info.llm_provider_name == "nim"
+
+        llm_call_info_var.set(None)
+
+    def test_get_from_cache_and_restore_stats_with_metadata(self):
+        cache = LFUCache(maxsize=10)
+        cache_entry: CacheEntry = {
+            "result": {"allowed": True, "policy_violations": []},
+            "llm_stats": {
+                "total_tokens": 100,
+                "prompt_tokens": 50,
+                "completion_tokens": 50,
+            },
+            "llm_metadata": {
+                "model_name": "gpt-4-turbo",
+                "provider_name": "openai",
+            },
+        }
+        cache.put("test_key", cache_entry)

+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+        llm_stats_var.set(None)
+
+        result = get_from_cache_and_restore_stats(cache, "test_key")
+
+        assert result is not None
+        assert result == {"allowed": True, "policy_violations": []}
+
+        updated_info = llm_call_info_var.get()
+        assert updated_info is not None
+        assert updated_info.from_cache is True
+        assert updated_info.llm_model_name == "gpt-4-turbo"
+        assert updated_info.llm_provider_name == "openai"
+
+        llm_call_info_var.set(None)
+        llm_stats_var.set(None)
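
Assuming the repository's standard pytest setup, the new coverage can be exercised selectively with `pytest tests/test_cache_utils.py -k metadata`, since every test added here has "metadata" in its name.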
