 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import MagicMock
 
 import pytest
 
 from nemoguardrails.context import llm_call_info_var, llm_stats_var
 from nemoguardrails.llm.cache.lfu import LFUCache
 from nemoguardrails.llm.cache.utils import (
+    CacheEntry,
+    LLMMetadataDict,
+    LLMStatsDict,
     create_normalized_cache_key,
+    extract_llm_metadata_for_cache,
     extract_llm_stats_for_cache,
     get_from_cache_and_restore_stats,
+    restore_llm_metadata_from_cache,
     restore_llm_stats_from_cache,
)
 from nemoguardrails.logging.explain import LLMCallInfo
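Note: the newly imported CacheEntry, LLMStatsDict, and LLMMetadataDict are the typed shapes the tests below annotate cache data with. A minimal sketch of what they presumably look like, inferred purely from how the tests construct them; the real definitions live in nemoguardrails.llm.cache.utils and may differ in detail:

# Minimal sketch inferred from the tests below -- not the actual
# definitions in nemoguardrails.llm.cache.utils.
from typing import Any, Dict, Optional, TypedDict


class LLMStatsDict(TypedDict):
    total_tokens: int
    prompt_tokens: int
    completion_tokens: int


class LLMMetadataDict(TypedDict):
    model_name: str
    provider_name: str


class CacheEntry(TypedDict):
    result: Dict[str, Any]
    llm_stats: Optional[LLMStatsDict]
    llm_metadata: Optional[LLMMetadataDict]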
@@ -97,10 +101,10 @@ def test_create_normalized_cache_key_different_for_different_input(
 
     def test_create_normalized_cache_key_invalid_type_raises_error(self):
         with pytest.raises(TypeError, match="Invalid type for prompt: int"):
-            create_normalized_cache_key(123)
+            create_normalized_cache_key(123)  # type: ignore
 
         with pytest.raises(TypeError, match="Invalid type for prompt: dict"):
-            create_normalized_cache_key({"key": "value"})
+            create_normalized_cache_key({"key": "value"})  # type: ignore
 
     def test_create_normalized_cache_key_list_of_dicts(self):
         messages = [
@@ -129,19 +133,19 @@ def test_create_normalized_cache_key_invalid_list_raises_error(self):
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key(["hello", "world"])
+            create_normalized_cache_key(["hello", "world"])  # type: ignore
 
         with pytest.raises(
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key([{"key": "value"}, "test"])
+            create_normalized_cache_key([{"key": "value"}, "test"])  # type: ignore
 
         with pytest.raises(
             TypeError,
             match="All elements in prompt list must be dictionaries",
         ):
-            create_normalized_cache_key([123, 456])
+            create_normalized_cache_key([123, 456])  # type: ignore
 
     def test_extract_llm_stats_for_cache_with_llm_call_info(self):
         llm_call_info = LLMCallInfo(task="test_task")
@@ -186,7 +190,7 @@ def test_restore_llm_stats_from_cache_creates_new_llm_stats(self):
         llm_stats_var.set(None)
         llm_call_info_var.set(None)
 
-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
             "completion_tokens": 50,
@@ -211,7 +215,7 @@ def test_restore_llm_stats_from_cache_updates_existing_llm_stats(self):
         llm_stats.inc("total_tokens", 200)
         llm_stats_var.set(llm_stats)
 
-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
             "completion_tokens": 50,
@@ -220,6 +224,7 @@ def test_restore_llm_stats_from_cache_updates_existing_llm_stats(self):
         restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.5)
 
         llm_stats = llm_stats_var.get()
+        assert llm_stats is not None
         assert llm_stats.get_stat("total_calls") == 6
         assert llm_stats.get_stat("total_time") == 1.5
         assert llm_stats.get_stat("total_tokens") == 300
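The arithmetic asserted here (calls 5 -> 6, time 1.0 -> 1.5, tokens 200 -> 300) implies that restoring cached stats adds one call, the cache-read duration, and the cached token count onto whatever stats already exist, creating a fresh LLMStats when none is set (as the previous test expects). A rough sketch under those assumptions, not the actual implementation:

# Rough sketch implied by the assertions above -- not the actual code.
# Assumes LLMStats comes from nemoguardrails.logging.stats.
from nemoguardrails.logging.stats import LLMStats


def restore_llm_stats_from_cache_sketch(
    cached: LLMStatsDict, cache_read_duration: float
) -> None:
    stats = llm_stats_var.get()
    if stats is None:
        # A missing stats object is created on demand
        # (test_restore_llm_stats_from_cache_creates_new_llm_stats).
        stats = LLMStats()
        llm_stats_var.set(stats)
    stats.inc("total_calls", 1)                        # 5 -> 6
    stats.inc("total_time", cache_read_duration)       # 1.0 -> 1.5
    stats.inc("total_tokens", cached["total_tokens"])  # 200 -> 300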
@@ -231,7 +236,7 @@ def test_restore_llm_stats_from_cache_updates_llm_call_info(self):
         llm_call_info_var.set(llm_call_info)
         llm_stats_var.set(None)
 
-        cached_stats = {
+        cached_stats: LLMStatsDict = {
             "total_tokens": 100,
             "prompt_tokens": 50,
             "completion_tokens": 50,
@@ -273,6 +278,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self):
                 "prompt_tokens": 50,
                 "completion_tokens": 50,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -291,6 +297,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self):
         assert llm_stats.get_stat("total_tokens") == 100
 
         updated_info = llm_call_info_var.get()
+        assert updated_info is not None
         assert updated_info.from_cache is True
 
         llm_call_info_var.set(None)
@@ -301,6 +308,7 @@ def test_get_from_cache_and_restore_stats_without_llm_stats(self):
         cache_entry = {
             "result": {"allowed": False, "policy_violations": ["policy1"]},
             "llm_stats": None,
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -324,6 +332,7 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
                 "prompt_tokens": 60,
                 "completion_tokens": 20,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -332,14 +341,15 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
         llm_stats_var.set(None)
 
         processing_log = []
-        processing_log_var.set(processing_log)
+        processing_log_var.set(processing_log)  # type: ignore
 
         result = get_from_cache_and_restore_stats(cache, "test_key")
 
         assert result is not None
         assert result == {"allowed": True, "policy_violations": []}
 
         retrieved_log = processing_log_var.get()
+        assert retrieved_log is not None
         assert len(retrieved_log) == 1
         assert retrieved_log[0]["type"] == "llm_call_info"
         assert "timestamp" in retrieved_log[0]
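For context, these assertions pin down the shape of the log record a cache hit appends. A hypothetical sketch of that side effect, assuming the processing_log_var context variable used elsewhere in this file; the real record likely carries more fields than the two asserted here:

# Hypothetical sketch of the side effect asserted above -- not the
# actual implementation inside get_from_cache_and_restore_stats.
import time

log = processing_log_var.get()
if log is not None:
    # Only "type" and "timestamp" are pinned down by the test.
    log.append({"type": "llm_call_info", "timestamp": time.time()})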
@@ -359,6 +369,7 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):
                 "prompt_tokens": 30,
                 "completion_tokens": 20,
             },
+            "llm_metadata": None,
         }
         cache.put("test_key", cache_entry)
 
@@ -374,3 +385,91 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):
 
         llm_call_info_var.set(None)
         llm_stats_var.set(None)
+
+    def test_extract_llm_metadata_for_cache_with_model_info(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info.llm_model_name = "gpt-4"
+        llm_call_info.llm_provider_name = "openai"
+        llm_call_info_var.set(llm_call_info)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is not None
+        assert metadata["model_name"] == "gpt-4"
+        assert metadata["provider_name"] == "openai"
+
+        llm_call_info_var.set(None)
+
+    def test_extract_llm_metadata_for_cache_without_llm_call_info(self):
+        llm_call_info_var.set(None)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is None
+
+    def test_extract_llm_metadata_for_cache_with_defaults(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is not None
+        assert metadata["model_name"] == "unknown"
+        assert metadata["provider_name"] == "unknown"
+
+        llm_call_info_var.set(None)
+
+    def test_restore_llm_metadata_from_cache(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+
+        cached_metadata: LLMMetadataDict = {
+            "model_name": "nvidia/llama-3.1-nemoguard-8b-content-safety",
+            "provider_name": "nim",
+        }
+
+        restore_llm_metadata_from_cache(cached_metadata)
+
+        updated_info = llm_call_info_var.get()
+        assert updated_info is not None
+        assert (
+            updated_info.llm_model_name
+            == "nvidia/llama-3.1-nemoguard-8b-content-safety"
+        )
+        assert updated_info.llm_provider_name == "nim"
+
+        llm_call_info_var.set(None)
+
+    def test_get_from_cache_and_restore_stats_with_metadata(self):
+        cache = LFUCache(maxsize=10)
+        cache_entry: CacheEntry = {
+            "result": {"allowed": True, "policy_violations": []},
+            "llm_stats": {
+                "total_tokens": 100,
+                "prompt_tokens": 50,
+                "completion_tokens": 50,
+            },
+            "llm_metadata": {
+                "model_name": "gpt-4-turbo",
+                "provider_name": "openai",
+            },
+        }
+        cache.put("test_key", cache_entry)
+
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+        llm_stats_var.set(None)
+
+        result = get_from_cache_and_restore_stats(cache, "test_key")
+
+        assert result is not None
+        assert result == {"allowed": True, "policy_violations": []}
+
+        updated_info = llm_call_info_var.get()
+        assert updated_info is not None
+        assert updated_info.from_cache is True
+        assert updated_info.llm_model_name == "gpt-4-turbo"
+        assert updated_info.llm_provider_name == "openai"
+
+        llm_call_info_var.set(None)
+        llm_stats_var.set(None)
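Read together, the metadata tests above pin down the contract of the two new helpers: extraction returns None when no LLM call is in flight, falls back to "unknown" for unset model/provider names, and restoration copies the cached names back onto the live LLMCallInfo. A hypothetical sketch consistent with those assertions, not the actual implementation in nemoguardrails.llm.cache.utils:

# Hypothetical sketch consistent with the tests above -- not the actual
# implementation in nemoguardrails.llm.cache.utils.
from typing import Optional


def extract_llm_metadata_for_cache_sketch() -> Optional[LLMMetadataDict]:
    info = llm_call_info_var.get()
    if info is None:
        # No LLM call in flight -> nothing to cache.
        return None
    return {
        "model_name": info.llm_model_name or "unknown",
        "provider_name": info.llm_provider_name or "unknown",
    }


def restore_llm_metadata_from_cache_sketch(metadata: LLMMetadataDict) -> None:
    # Copies cached model/provider names onto the current LLMCallInfo.
    info = llm_call_info_var.get()
    if info is not None:
        info.llm_model_name = metadata["model_name"]
        info.llm_provider_name = metadata["provider_name"]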