# See the License for the specific language governing permissions and
# limitations under the License.

-from unittest.mock import MagicMock

import pytest

from nemoguardrails.context import llm_call_info_var, llm_stats_var
from nemoguardrails.llm.cache.lfu import LFUCache
from nemoguardrails.llm.cache.utils import (
+    CacheEntry,
+    LLMMetadataDict,
+    LLMStatsDict,
    create_normalized_cache_key,
+    extract_llm_metadata_for_cache,
    extract_llm_stats_for_cache,
    get_from_cache_and_restore_stats,
+    restore_llm_metadata_from_cache,
    restore_llm_stats_from_cache,
)
from nemoguardrails.logging.explain import LLMCallInfo


class TestCacheUtils:
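+    # Reset llm_call_info_var before and after each test so no state leaks between tests.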
+    @pytest.fixture(autouse=True)
+    def isolated_llm_call_info_var(self):
+        llm_call_info_var.set(None)
+        yield
+        llm_call_info_var.set(None)
+
    def test_create_normalized_cache_key_returns_sha256_hash(self):
        key = create_normalized_cache_key("Hello world")
        assert len(key) == 64
@@ -97,10 +107,10 @@ def test_create_normalized_cache_key_different_for_different_input(

    def test_create_normalized_cache_key_invalid_type_raises_error(self):
        with pytest.raises(TypeError, match="Invalid type for prompt: int"):
-            create_normalized_cache_key(123)
+            create_normalized_cache_key(123)  # type: ignore

        with pytest.raises(TypeError, match="Invalid type for prompt: dict"):
-            create_normalized_cache_key({"key": "value"})
+            create_normalized_cache_key({"key": "value"})  # type: ignore

    def test_create_normalized_cache_key_list_of_dicts(self):
        messages = [
@@ -129,25 +139,24 @@ def test_create_normalized_cache_key_invalid_list_raises_error(self):
            TypeError,
            match="All elements in prompt list must be dictionaries",
        ):
-            create_normalized_cache_key(["hello", "world"])
+            create_normalized_cache_key(["hello", "world"])  # type: ignore

        with pytest.raises(
            TypeError,
            match="All elements in prompt list must be dictionaries",
        ):
-            create_normalized_cache_key([{"key": "value"}, "test"])
+            create_normalized_cache_key([{"key": "value"}, "test"])  # type: ignore

        with pytest.raises(
            TypeError,
            match="All elements in prompt list must be dictionaries",
        ):
-            create_normalized_cache_key([123, 456])
+            create_normalized_cache_key([123, 456])  # type: ignore

    def test_extract_llm_stats_for_cache_with_llm_call_info(self):
-        llm_call_info = LLMCallInfo(task="test_task")
-        llm_call_info.total_tokens = 100
-        llm_call_info.prompt_tokens = 50
-        llm_call_info.completion_tokens = 50
+        llm_call_info = LLMCallInfo(
+            task="test_task", total_tokens=100, prompt_tokens=50, completion_tokens=50
+        )
        llm_call_info_var.set(llm_call_info)

        stats = extract_llm_stats_for_cache()
@@ -167,10 +176,12 @@ def test_extract_llm_stats_for_cache_without_llm_call_info(self):
        assert stats is None

    def test_extract_llm_stats_for_cache_with_none_values(self):
-        llm_call_info = LLMCallInfo(task="test_task")
-        llm_call_info.total_tokens = None
-        llm_call_info.prompt_tokens = None
-        llm_call_info.completion_tokens = None
+        llm_call_info = LLMCallInfo(
+            task="test_task",
+            total_tokens=None,
+            prompt_tokens=None,
+            completion_tokens=None,
+        )
        llm_call_info_var.set(llm_call_info)

        stats = extract_llm_stats_for_cache()
@@ -186,13 +197,13 @@ def test_restore_llm_stats_from_cache_creates_new_llm_stats(self):
        llm_stats_var.set(None)
        llm_call_info_var.set(None)

-        cached_stats = {
+        cached_stats: LLMStatsDict = {
            "total_tokens": 100,
            "prompt_tokens": 50,
            "completion_tokens": 50,
        }

-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.01)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.01)

        llm_stats = llm_stats_var.get()
        assert llm_stats is not None
@@ -211,15 +222,16 @@ def test_restore_llm_stats_from_cache_updates_existing_llm_stats(self):
        llm_stats.inc("total_tokens", 200)
        llm_stats_var.set(llm_stats)

-        cached_stats = {
+        cached_stats: LLMStatsDict = {
            "total_tokens": 100,
            "prompt_tokens": 50,
            "completion_tokens": 50,
        }

-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.5)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.5)

        llm_stats = llm_stats_var.get()
+        assert llm_stats is not None
        assert llm_stats.get_stat("total_calls") == 6
        assert llm_stats.get_stat("total_time") == 1.5
        assert llm_stats.get_stat("total_tokens") == 300
@@ -231,13 +243,13 @@ def test_restore_llm_stats_from_cache_updates_llm_call_info(self):
        llm_call_info_var.set(llm_call_info)
        llm_stats_var.set(None)

-        cached_stats = {
+        cached_stats: LLMStatsDict = {
            "total_tokens": 100,
            "prompt_tokens": 50,
            "completion_tokens": 50,
        }

-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.02)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.02)

        updated_info = llm_call_info_var.get()
        assert updated_info is not None
@@ -273,6 +285,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self):
                "prompt_tokens": 50,
                "completion_tokens": 50,
            },
+            "llm_metadata": None,
        }
        cache.put("test_key", cache_entry)

@@ -291,6 +304,7 @@ def test_get_from_cache_and_restore_stats_cache_hit(self):
        assert llm_stats.get_stat("total_tokens") == 100

        updated_info = llm_call_info_var.get()
+        assert updated_info is not None
        assert updated_info.from_cache is True

        llm_call_info_var.set(None)
@@ -301,6 +315,7 @@ def test_get_from_cache_and_restore_stats_without_llm_stats(self):
        cache_entry = {
            "result": {"allowed": False, "policy_violations": ["policy1"]},
            "llm_stats": None,
+            "llm_metadata": None,
        }
        cache.put("test_key", cache_entry)

@@ -324,6 +339,7 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
                "prompt_tokens": 60,
                "completion_tokens": 20,
            },
+            "llm_metadata": None,
        }
        cache.put("test_key", cache_entry)

@@ -332,14 +348,15 @@ def test_get_from_cache_and_restore_stats_with_processing_log(self):
        llm_stats_var.set(None)

        processing_log = []
-        processing_log_var.set(processing_log)
+        processing_log_var.set(processing_log)  # type: ignore

        result = get_from_cache_and_restore_stats(cache, "test_key")

        assert result is not None
        assert result == {"allowed": True, "policy_violations": []}

        retrieved_log = processing_log_var.get()
+        assert retrieved_log is not None
        assert len(retrieved_log) == 1
        assert retrieved_log[0]["type"] == "llm_call_info"
        assert "timestamp" in retrieved_log[0]
@@ -359,6 +376,7 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):
                "prompt_tokens": 30,
                "completion_tokens": 20,
            },
+            "llm_metadata": None,
        }
        cache.put("test_key", cache_entry)

@@ -374,3 +392,91 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):

        llm_call_info_var.set(None)
        llm_stats_var.set(None)
+
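+    # Coverage for the llm_metadata helpers: extraction, restoration, and metadata-aware cache hits.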
+    def test_extract_llm_metadata_for_cache_with_model_info(self):
+        llm_call_info = LLMCallInfo(
+            task="test_task", llm_model_name="gpt-4", llm_provider_name="openai"
+        )
+        llm_call_info_var.set(llm_call_info)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is not None
+        assert metadata["model_name"] == "gpt-4"
+        assert metadata["provider_name"] == "openai"
+
+        llm_call_info_var.set(None)
+
+    def test_extract_llm_metadata_for_cache_without_llm_call_info(self):
+        llm_call_info_var.set(None)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is None
+
+    def test_extract_llm_metadata_for_cache_with_defaults(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+
+        metadata = extract_llm_metadata_for_cache()
+
+        assert metadata is not None
+        assert metadata["model_name"] == "unknown"
+        assert metadata["provider_name"] == "unknown"
+
+        llm_call_info_var.set(None)
+
+    def test_restore_llm_metadata_from_cache(self):
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+
+        cached_metadata: LLMMetadataDict = {
+            "model_name": "nvidia/llama-3.1-nemoguard-8b-content-safety",
+            "provider_name": "nim",
+        }
+
+        restore_llm_metadata_from_cache(cached_metadata)
+
+        updated_info = llm_call_info_var.get()
+        assert updated_info is not None
+        assert (
+            updated_info.llm_model_name
+            == "nvidia/llama-3.1-nemoguard-8b-content-safety"
+        )
+        assert updated_info.llm_provider_name == "nim"
+
+        llm_call_info_var.set(None)
+
+    def test_get_from_cache_and_restore_stats_with_metadata(self):
+        cache = LFUCache(maxsize=10)
+        cache_entry: CacheEntry = {
+            "result": {"allowed": True, "policy_violations": []},
+            "llm_stats": {
+                "total_tokens": 100,
+                "prompt_tokens": 50,
+                "completion_tokens": 50,
+            },
+            "llm_metadata": {
+                "model_name": "gpt-4-turbo",
+                "provider_name": "openai",
+            },
+        }
+        cache.put("test_key", cache_entry)
+
+        llm_call_info = LLMCallInfo(task="test_task")
+        llm_call_info_var.set(llm_call_info)
+        llm_stats_var.set(None)
+
+        result = get_from_cache_and_restore_stats(cache, "test_key")
+
+        assert result is not None
+        assert result == {"allowed": True, "policy_violations": []}
+
+        updated_info = llm_call_info_var.get()
+        assert updated_info is not None
+        assert updated_info.from_cache is True
+        assert updated_info.llm_model_name == "gpt-4-turbo"
+        assert updated_info.llm_provider_name == "openai"
+
+        llm_call_info_var.set(None)
+        llm_stats_var.set(None)