Commit 0eec1f6

feat(cache): add LLM metadata caching for model and provider information (#1456)
* feat(cache): add LLM metadata caching for model and provider information

  Extends the cache system to store and restore LLM metadata (model name and provider name) alongside cache entries. This allows cached results to maintain provenance information about which model and provider generated the original response.

  - Added LLMMetadataDict and LLMCacheData TypedDict definitions for type safety
  - Extended CacheEntry to include an optional llm_metadata field
  - Implemented extract_llm_metadata_for_cache() to capture model and provider info from the context
  - Implemented restore_llm_metadata_from_cache() to restore metadata when retrieving cached results
  - Updated get_from_cache_and_restore_stats() to handle metadata extraction and restoration
  - Added comprehensive test coverage for metadata caching functionality

* address review comments: add _s suffix to durations, add test fixture, refactor LLMCallInfo instantiation
1 parent 84203a9 commit 0eec1f6
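For context, a minimal sketch of the cache-entry shape this commit introduces. It uses only the fields defined by the CacheEntry TypedDict in the diff below; the concrete result, token counts, and model/provider values are illustrative placeholders, not values produced by the library.

    from nemoguardrails.llm.cache.utils import CacheEntry

    entry: CacheEntry = {
        # The guardrail check result originally computed by the LLM.
        "result": {"allowed": True, "policy_violations": []},
        # Token accounting restored into LLMStats on a cache hit.
        "llm_stats": {
            "total_tokens": 100,
            "prompt_tokens": 50,
            "completion_tokens": 50,
        },
        # New in this commit: provenance of the original (uncached) response.
        "llm_metadata": {
            "model_name": "gpt-4",
            "provider_name": "openai",
        },
    }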

File tree

2 files changed: +168 -28 lines changed


nemoguardrails/llm/cache/utils.py

Lines changed: 41 additions & 7 deletions
@@ -35,9 +35,20 @@ class LLMStatsDict(TypedDict):
     completion_tokens: int
 
 
+class LLMMetadataDict(TypedDict):
+    model_name: str
+    provider_name: str
+
+
+class LLMCacheData(TypedDict):
+    stats: Optional[LLMStatsDict]
+    metadata: Optional[LLMMetadataDict]
+
+
 class CacheEntry(TypedDict):
     result: dict
     llm_stats: Optional[LLMStatsDict]
+    llm_metadata: Optional[LLMMetadataDict]
 
 
 def create_normalized_cache_key(
@@ -95,27 +106,27 @@ def create_normalized_cache_key(
 
 
 def restore_llm_stats_from_cache(
-    cached_stats: LLMStatsDict, cache_read_duration: float
+    cached_stats: LLMStatsDict, cache_read_duration_s: float
 ) -> None:
     llm_stats = llm_stats_var.get()
     if llm_stats is None:
         llm_stats = LLMStats()
         llm_stats_var.set(llm_stats)
 
     llm_stats.inc("total_calls")
-    llm_stats.inc("total_time", cache_read_duration)
+    llm_stats.inc("total_time", cache_read_duration_s)
     llm_stats.inc("total_tokens", cached_stats.get("total_tokens", 0))
     llm_stats.inc("total_prompt_tokens", cached_stats.get("prompt_tokens", 0))
     llm_stats.inc("total_completion_tokens", cached_stats.get("completion_tokens", 0))
 
     llm_call_info = llm_call_info_var.get()
     if llm_call_info:
-        llm_call_info.duration = cache_read_duration
+        llm_call_info.duration = cache_read_duration_s
         llm_call_info.total_tokens = cached_stats.get("total_tokens", 0)
         llm_call_info.prompt_tokens = cached_stats.get("prompt_tokens", 0)
         llm_call_info.completion_tokens = cached_stats.get("completion_tokens", 0)
         llm_call_info.from_cache = True
-        llm_call_info.started_at = time() - cache_read_duration
+        llm_call_info.started_at = time() - cache_read_duration_s
         llm_call_info.finished_at = time()
 
 
@@ -130,20 +141,43 @@ def extract_llm_stats_for_cache() -> Optional[LLMStatsDict]:
     return None
 
 
+def extract_llm_metadata_for_cache() -> Optional[LLMMetadataDict]:
+    llm_call_info = llm_call_info_var.get()
+    if llm_call_info:
+        return {
+            "model_name": llm_call_info.llm_model_name or "unknown",
+            "provider_name": llm_call_info.llm_provider_name or "unknown",
+        }
+    return None
+
+
+def restore_llm_metadata_from_cache(cached_metadata: LLMMetadataDict) -> None:
+    llm_call_info = llm_call_info_var.get()
+    if llm_call_info:
+        llm_call_info.llm_model_name = cached_metadata.get("model_name", "unknown")
+        llm_call_info.llm_provider_name = cached_metadata.get(
+            "provider_name", "unknown"
+        )
+
+
 def get_from_cache_and_restore_stats(
     cache: "CacheInterface", cache_key: str
 ) -> Optional[dict]:
     cached_entry = cache.get(cache_key)
     if cached_entry is None:
         return None
 
-    cache_read_start = time()
+    cache_read_start_s = time()
     final_result = cached_entry["result"]
     cached_stats = cached_entry.get("llm_stats")
-    cache_read_duration = time() - cache_read_start
+    cached_metadata = cached_entry.get("llm_metadata")
+    cache_read_duration_s = time() - cache_read_start_s
 
     if cached_stats:
-        restore_llm_stats_from_cache(cached_stats, cache_read_duration)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s)
+
+    if cached_metadata:
+        restore_llm_metadata_from_cache(cached_metadata)
 
     processing_log = processing_log_var.get()
     if processing_log is not None:
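For orientation, a minimal usage sketch assembled from the helpers changed in this file and the patterns used in the tests below. The prompt text, cache size, task name, and model/provider values are illustrative assumptions; the guardrails runtime wires these calls internally rather than being driven like this by user code.

    from nemoguardrails.context import llm_call_info_var
    from nemoguardrails.llm.cache.lfu import LFUCache
    from nemoguardrails.llm.cache.utils import (
        CacheEntry,
        create_normalized_cache_key,
        extract_llm_metadata_for_cache,
        extract_llm_stats_for_cache,
        get_from_cache_and_restore_stats,
    )
    from nemoguardrails.logging.explain import LLMCallInfo

    cache = LFUCache(maxsize=1000)
    key = create_normalized_cache_key("Is this request allowed?")

    # After a (hypothetical) real LLM call, the call info sits in the context var;
    # capture both token stats and the new model/provider metadata for the entry.
    llm_call_info_var.set(
        LLMCallInfo(
            task="content_safety_check",  # illustrative task name
            total_tokens=100,
            prompt_tokens=50,
            completion_tokens=50,
            llm_model_name="gpt-4",
            llm_provider_name="openai",
        )
    )
    entry: CacheEntry = {
        "result": {"allowed": True, "policy_violations": []},
        "llm_stats": extract_llm_stats_for_cache(),
        "llm_metadata": extract_llm_metadata_for_cache(),
    }
    cache.put(key, entry)

    # On a later identical request, a cache hit restores token stats and the
    # model/provider provenance into the current LLMCallInfo.
    llm_call_info_var.set(LLMCallInfo(task="content_safety_check"))
    result = get_from_cache_and_restore_stats(cache, key)
    info = llm_call_info_var.get()
    assert result == {"allowed": True, "policy_violations": []}
    assert info is not None and info.llm_model_name == "gpt-4"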

tests/test_cache_utils.py

Lines changed: 127 additions & 21 deletions
@@ -13,16 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import MagicMock
 
 import pytest
 
 from nemoguardrails.context import llm_call_info_var, llm_stats_var
 from nemoguardrails.llm.cache.lfu import LFUCache
 from nemoguardrails.llm.cache.utils import (
+    CacheEntry,
+    LLMMetadataDict,
+    LLMStatsDict,
     create_normalized_cache_key,
+    extract_llm_metadata_for_cache,
     extract_llm_stats_for_cache,
     get_from_cache_and_restore_stats,
+    restore_llm_metadata_from_cache,
     restore_llm_stats_from_cache,
 )
 from nemoguardrails.logging.explain import LLMCallInfo
@@ -31,6 +35,12 @@
 
 
 class TestCacheUtils:
+    @pytest.fixture(autouse=True)
+    def isolated_llm_call_info_var(self):
+        llm_call_info_var.set(None)
+        yield
+        llm_call_info_var.set(None)
+
     def test_create_normalized_cache_key_returns_sha256_hash(self):
         key = create_normalized_cache_key("Hello world")
         assert len(key) == 64
@@ -97,10 +107,10 @@ def test_create_normalized_cache_key_different_for_different_input(
 
     def test_create_normalized_cache_key_invalid_type_raises_error(self):
         with pytest.raises(TypeError, match="Invalid type for prompt: int"):
-            create_normalized_cache_key(123)
+            create_normalized_cache_key(123)  # type: ignore
 
         with pytest.raises(TypeError, match="Invalid type for prompt: dict"):
-            create_normalized_cache_key({"key": "value"})
+            create_normalized_cache_key({"key": "value"})  # type: ignore
 
     def test_create_normalized_cache_key_list_of_dicts(self):
         messages = [
@@ -129,25 +139,24 @@ def test_create_normalized_cache_key_invalid_list_raises_error(self):
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key(["hello", "world"])
+            create_normalized_cache_key(["hello", "world"])  # type: ignore
 
         with pytest.raises(
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key([{"key": "value"}, "test"])
+            create_normalized_cache_key([{"key": "value"}, "test"])  # type: ignore
 
         with pytest.raises(
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key([123, 456])
+            create_normalized_cache_key([123, 456])  # type: ignore
 
     def test_extract_llm_stats_for_cache_with_llm_call_info(self):
-        llm_call_info = LLMCallInfo(task="test_task")
-        llm_call_info.total_tokens = 100
-        llm_call_info.prompt_tokens = 50
-        llm_call_info.completion_tokens = 50
+        llm_call_info = LLMCallInfo(
+            task="test_task", total_tokens=100, prompt_tokens=50, completion_tokens=50
+        )
         llm_call_info_var.set(llm_call_info)
 
         stats = extract_llm_stats_for_cache()
@@ -167,10 +176,12 @@ def test_extract_llm_stats_for_cache_without_llm_call_info(self):
         assert stats is None
 
     def test_extract_llm_stats_for_cache_with_none_values(self):
-        llm_call_info = LLMCallInfo(task="test_task")
-        llm_call_info.total_tokens = None
-        llm_call_info.prompt_tokens = None
-        llm_call_info.completion_tokens = None
+        llm_call_info = LLMCallInfo(
+            task="test_task",
+            total_tokens=None,
+            prompt_tokens=None,
+            completion_tokens=None,
+        )
         llm_call_info_var.set(llm_call_info)
 
         stats = extract_llm_stats_for_cache()
@@ -186,13 +197,13 @@ def test_restore_llm_stats_from_cache_creates_new_llm_stats(self):
         llm_stats_var.set(None)
         llm_call_info_var.set(None)
 
-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
            "completion_tokens": 50,
         }
 
-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.01)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.01)
 
         llm_stats = llm_stats_var.get()
         assert llm_stats is not None
@@ -211,15 +222,16 @@ def test_restore_llm_stats_from_cache_updates_existing_llm_stats(self):
         llm_stats.inc("total_tokens", 200)
         llm_stats_var.set(llm_stats)
 
-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
             "completion_tokens": 50,
         }
 
-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.5)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.5)
 
         llm_stats = llm_stats_var.get()
+        assert llm_stats is not None
         assert llm_stats.get_stat("total_calls") == 6
         assert llm_stats.get_stat("total_time") == 1.5
         assert llm_stats.get_stat("total_tokens") == 300
@@ -231,13 +243,13 @@ def test_restore_llm_stats_from_cache_updates_llm_call_info(self):
         llm_call_info_var.set(llm_call_info)
         llm_stats_var.set(None)
 
-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
             "completion_tokens": 50,
         }
 
-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.02)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.02)
 
         updated_info = llm_call_info_var.get()
         assert updated_info is not None
@@ -273,6 +285,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self):
                 "prompt_tokens": 50,
                 "completion_tokens": 50,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -291,6 +304,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self):
         assert llm_stats.get_stat("total_tokens") == 100
 
         updated_info = llm_call_info_var.get()
+        assert updated_info is not None
         assert updated_info.from_cache is True
 
         llm_call_info_var.set(None)
@@ -301,6 +315,7 @@ def test_get_from_cache_and_restore_stats_without_llm_stats(self):
         cache_entry = {
             "result": {"allowed": False, "policy_violations": ["policy1"]},
             "llm_stats": None,
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -324,6 +339,7 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
                 "prompt_tokens": 60,
                 "completion_tokens": 20,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -332,14 +348,15 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
         llm_stats_var.set(None)
 
         processing_log = []
-        processing_log_var.set(processing_log)
+        processing_log_var.set(processing_log)  # type: ignore
 
         result = get_from_cache_and_restore_stats(cache, "test_key")
 
         assert result is not None
         assert result == {"allowed": True, "policy_violations": []}
 
         retrieved_log = processing_log_var.get()
+        assert retrieved_log is not None
         assert len(retrieved_log) == 1
         assert retrieved_log[0]["type"] == "llm_call_info"
         assert "timestamp" in retrieved_log[0]
@@ -359,6 +376,7 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):
                 "prompt_tokens": 30,
                 "completion_tokens": 20,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -374,3 +392,91 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):
 
         llm_call_info_var.set(None)
         llm_stats_var.set(None)
+
+    def test_extract_llm_metadata_for_cache_with_model_info(self):
+        llm_call_info = LLMCallInfo(
+            task="test_task", llm_model_name="gpt-4", llm_provider_name="openai"
+        )
+        llm_call_info_var.set(llm_call_info)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is not None
+        assert metadata["model_name"] == "gpt-4"
+        assert metadata["provider_name"] == "openai"
+
+        llm_call_info_var.set(None)
+
+    def test_extract_llm_metadata_for_cache_without_llm_call_info(self):
+        llm_call_info_var.set(None)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is None
+
+    def test_extract_llm_metadata_for_cache_with_defaults(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is not None
+        assert metadata["model_name"] == "unknown"
+        assert metadata["provider_name"] == "unknown"
+
+        llm_call_info_var.set(None)
+
+    def test_restore_llm_metadata_from_cache(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+
+        cached_metadata: LLMMetadataDict = {
+            "model_name": "nvidia/llama-3.1-nemoguard-8b-content-safety",
+            "provider_name": "nim",
+        }
+
+        restore_llm_metadata_from_cache(cached_metadata)
+
+        updated_info = llm_call_info_var.get()
+        assert updated_info is not None
+        assert (
+            updated_info.llm_model_name
+            == "nvidia/llama-3.1-nemoguard-8b-content-safety"
+        )
+        assert updated_info.llm_provider_name == "nim"
+
+        llm_call_info_var.set(None)
+
+    def test_get_from_cache_and_restore_stats_with_metadata(self):
+        cache = LFUCache(maxsize=10)
+        cache_entry: CacheEntry = {
+            "result": {"allowed": True, "policy_violations": []},
+            "llm_stats": {
+                "total_tokens": 100,
+                "prompt_tokens": 50,
+                "completion_tokens": 50,
+            },
+            "llm_metadata": {
+                "model_name": "gpt-4-turbo",
+                "provider_name": "openai",
+            },
+        }
+        cache.put("test_key", cache_entry)
+
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+        llm_stats_var.set(None)
+
+        result = get_from_cache_and_restore_stats(cache, "test_key")
+
+        assert result is not None
+        assert result == {"allowed": True, "policy_violations": []}
+
+        updated_info = llm_call_info_var.get()
+        assert updated_info is not None
+        assert updated_info.from_cache is True
+        assert updated_info.llm_model_name == "gpt-4-turbo"
+        assert updated_info.llm_provider_name == "openai"
+
+        llm_call_info_var.set(None)
+        llm_stats_var.set(None)
