
Commit fd873b7

address review comments: add _s suffix to durations, add test fixture, refactor LLMCallInfo instantiation
1 parent: a997dd3

2 files changed: +28 -21 lines

nemoguardrails/llm/cache/utils.py

Lines changed: 7 additions & 7 deletions
@@ -106,27 +106,27 @@ def create_normalized_cache_key(


 def restore_llm_stats_from_cache(
-    cached_stats: LLMStatsDict, cache_read_duration: float
+    cached_stats: LLMStatsDict, cache_read_duration_s: float
 ) -> None:
     llm_stats = llm_stats_var.get()
     if llm_stats is None:
         llm_stats = LLMStats()
         llm_stats_var.set(llm_stats)

     llm_stats.inc("total_calls")
-    llm_stats.inc("total_time", cache_read_duration)
+    llm_stats.inc("total_time", cache_read_duration_s)
     llm_stats.inc("total_tokens", cached_stats.get("total_tokens", 0))
     llm_stats.inc("total_prompt_tokens", cached_stats.get("prompt_tokens", 0))
     llm_stats.inc("total_completion_tokens", cached_stats.get("completion_tokens", 0))

     llm_call_info = llm_call_info_var.get()
     if llm_call_info:
-        llm_call_info.duration = cache_read_duration
+        llm_call_info.duration = cache_read_duration_s
         llm_call_info.total_tokens = cached_stats.get("total_tokens", 0)
         llm_call_info.prompt_tokens = cached_stats.get("prompt_tokens", 0)
         llm_call_info.completion_tokens = cached_stats.get("completion_tokens", 0)
         llm_call_info.from_cache = True
-        llm_call_info.started_at = time() - cache_read_duration
+        llm_call_info.started_at = time() - cache_read_duration_s
         llm_call_info.finished_at = time()

@@ -167,14 +167,14 @@ def get_from_cache_and_restore_stats(
     if cached_entry is None:
         return None

-    cache_read_start = time()
+    cache_read_start_s = time()
     final_result = cached_entry["result"]
     cached_stats = cached_entry.get("llm_stats")
     cached_metadata = cached_entry.get("llm_metadata")
-    cache_read_duration = time() - cache_read_start
+    cache_read_duration_s = time() - cache_read_start_s

     if cached_stats:
-        restore_llm_stats_from_cache(cached_stats, cache_read_duration)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s)

     if cached_metadata:
         restore_llm_metadata_from_cache(cached_metadata)
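
Both hunks above are the same mechanical rename. As a quick sanity check of what the seconds-suffixed parameter carries, here is a minimal, self-contained sketch of the timing pattern; the nemoguardrails.context import path for llm_stats_var is an assumption, while the restore_llm_stats_from_cache signature is taken from the diff itself.

from time import time

from nemoguardrails.context import llm_stats_var  # assumed location of the context var
from nemoguardrails.llm.cache.utils import restore_llm_stats_from_cache

# Stats as they would have been stored alongside a cached LLM result.
cached_stats = {"total_tokens": 100, "prompt_tokens": 50, "completion_tokens": 50}

# time() deltas are seconds, which is exactly what the new _s suffix signals.
cache_read_start_s = time()
# ... the cache backend lookup would happen here ...
cache_read_duration_s = time() - cache_read_start_s

restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=cache_read_duration_s)

# total_calls, total_time and the token counters now reflect the cached call.
print(llm_stats_var.get())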

tests/test_cache_utils.py

Lines changed: 21 additions & 14 deletions
@@ -35,6 +35,12 @@


 class TestCacheUtils:
+    @pytest.fixture(autouse=True)
+    def isolated_llm_call_info_var(self):
+        llm_call_info_var.set(None)
+        yield
+        llm_call_info_var.set(None)
+
     def test_create_normalized_cache_key_returns_sha256_hash(self):
         key = create_normalized_cache_key("Hello world")
         assert len(key) == 64
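
The autouse fixture is there because llm_call_info_var is a module-level ContextVar, so anything one test sets on it would otherwise leak into the next test in the class. A generic, self-contained sketch of the same isolation pattern (the names below are illustrative, not taken from the repository):

import contextvars

import pytest

# Illustrative stand-in for a shared context variable such as llm_call_info_var.
some_state_var: contextvars.ContextVar = contextvars.ContextVar("some_state_var", default=None)


class TestWithIsolatedState:
    @pytest.fixture(autouse=True)
    def isolated_state_var(self):
        # Reset before each test so earlier tests cannot leak state in...
        some_state_var.set(None)
        yield
        # ...and after it, so this test cannot leak state out.
        some_state_var.set(None)

    def test_starts_clean(self):
        assert some_state_var.get() is None
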
@@ -148,10 +154,9 @@ def test_create_normalized_cache_key_invalid_list_raises_error(self):
         create_normalized_cache_key([123, 456])  # type: ignore

     def test_extract_llm_stats_for_cache_with_llm_call_info(self):
-        llm_call_info = LLMCallInfo(task="test_task")
-        llm_call_info.total_tokens = 100
-        llm_call_info.prompt_tokens = 50
-        llm_call_info.completion_tokens = 50
+        llm_call_info = LLMCallInfo(
+            task="test_task", total_tokens=100, prompt_tokens=50, completion_tokens=50
+        )
         llm_call_info_var.set(llm_call_info)

         stats = extract_llm_stats_for_cache()
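
The instantiation refactor replaces post-construction attribute assignment with constructor keyword arguments. Assuming LLMCallInfo is a Pydantic model (the diff does not show its definition), the constructor form also means values are validated up front, whereas plain attribute assignment bypasses validation unless validate_assignment is enabled. A sketch with a hypothetical stand-in model:

from typing import Optional

from pydantic import BaseModel, ValidationError


class CallInfoSketch(BaseModel):
    """Hypothetical stand-in for LLMCallInfo, used only to illustrate the refactor."""

    task: str
    total_tokens: Optional[int] = None
    prompt_tokens: Optional[int] = None
    completion_tokens: Optional[int] = None


# Constructor kwargs: every field is set and validated in a single expression.
info = CallInfoSketch(
    task="test_task", total_tokens=100, prompt_tokens=50, completion_tokens=50
)

# A bad value is rejected at construction time...
try:
    CallInfoSketch(task="test_task", total_tokens="lots")
except ValidationError as exc:
    print(exc)

# ...while post-construction assignment is silently accepted by default.
info.total_tokens = "lots"
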
@@ -171,10 +176,12 @@ def test_extract_llm_stats_for_cache_without_llm_call_info(self):
         assert stats is None

     def test_extract_llm_stats_for_cache_with_none_values(self):
-        llm_call_info = LLMCallInfo(task="test_task")
-        llm_call_info.total_tokens = None
-        llm_call_info.prompt_tokens = None
-        llm_call_info.completion_tokens = None
+        llm_call_info = LLMCallInfo(
+            task="test_task",
+            total_tokens=None,
+            prompt_tokens=None,
+            completion_tokens=None,
+        )
         llm_call_info_var.set(llm_call_info)

         stats = extract_llm_stats_for_cache()

@@ -196,7 +203,7 @@ def test_restore_llm_stats_from_cache_creates_new_llm_stats(self):
             "completion_tokens": 50,
         }

-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.01)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.01)

         llm_stats = llm_stats_var.get()
         assert llm_stats is not None

@@ -221,7 +228,7 @@ def test_restore_llm_stats_from_cache_updates_existing_llm_stats(self):
             "completion_tokens": 50,
         }

-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.5)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.5)

         llm_stats = llm_stats_var.get()
         assert llm_stats is not None

@@ -242,7 +249,7 @@ def test_restore_llm_stats_from_cache_updates_llm_call_info(self):
             "completion_tokens": 50,
         }

-        restore_llm_stats_from_cache(cached_stats, cache_read_duration=0.02)
+        restore_llm_stats_from_cache(cached_stats, cache_read_duration_s=0.02)

         updated_info = llm_call_info_var.get()
         assert updated_info is not None

@@ -387,9 +394,9 @@ def test_get_from_cache_and_restore_stats_without_processing_log(self):
         llm_stats_var.set(None)

     def test_extract_llm_metadata_for_cache_with_model_info(self):
-        llm_call_info = LLMCallInfo(task="test_task")
-        llm_call_info.llm_model_name = "gpt-4"
-        llm_call_info.llm_provider_name = "openai"
+        llm_call_info = LLMCallInfo(
+            task="test_task", llm_model_name="gpt-4", llm_provider_name="openai"
+        )
         llm_call_info_var.set(llm_call_info)

         metadata = extract_llm_metadata_for_cache()
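
For context, the metadata tests exercise a write-then-restore round trip around the cache. A hedged sketch of that flow is below; the import paths for llm_call_info_var and LLMCallInfo are assumptions, and so is the behavior that restore_llm_metadata_from_cache writes the cached metadata back onto the current call info (by analogy with restore_llm_stats_from_cache above).

from nemoguardrails.context import llm_call_info_var  # assumed module path
from nemoguardrails.llm.cache.utils import (
    extract_llm_metadata_for_cache,
    restore_llm_metadata_from_cache,
)
from nemoguardrails.logging.explain import LLMCallInfo  # assumed module path

# Before caching: capture model/provider metadata from the active call.
llm_call_info_var.set(
    LLMCallInfo(task="test_task", llm_model_name="gpt-4", llm_provider_name="openai")
)
cached_metadata = extract_llm_metadata_for_cache()

# On a later cache hit: a fresh call info exists and the cached metadata is restored onto it.
llm_call_info_var.set(LLMCallInfo(task="test_task"))
restore_llm_metadata_from_cache(cached_metadata)

restored = llm_call_info_var.get()
assert restored is not None and restored.llm_model_name == "gpt-4"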
