Skip to content

Commit

Permalink
bugfix: added entry for callback to clear _node_data_buffer (run-llam…
Browse files Browse the repository at this point in the history
  • Loading branch information
azurewtl authored and Ryan Peach committed Feb 7, 2024
1 parent 3f721a8 commit 36384dc
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
10 changes: 8 additions & 2 deletions docs/examples/callbacks/OpenInferenceCallback.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
"from llama_index.callbacks.open_inference_callback import (\n",
" as_dataframe,\n",
" QueryData,\n",
" NodeData,\n",
")\n",
"from llama_index.node_parser import SimpleNodeParser\n",
"import pandas as pd\n",
Expand Down Expand Up @@ -845,13 +846,18 @@
" self._max_buffer_length = max_buffer_length\n",
" self._batch_index = 0\n",
"\n",
" def __call__(self, query_data_buffer: List[QueryData]) -> None:\n",
" def __call__(\n",
" self,\n",
" query_data_buffer: List[QueryData],\n",
" node_data_buffer: List[NodeData],\n",
" ) -> None:\n",
" if len(query_data_buffer) > self._max_buffer_length:\n",
" query_dataframe = as_dataframe(query_data_buffer)\n",
" file_path = self._data_path / f\"log-{self._batch_index}.parquet\"\n",
" query_dataframe.to_parquet(file_path)\n",
" self._batch_index += 1\n",
" query_data_buffer.clear() # ⚠️ clear the buffer or it will keep growing forever!"
" query_data_buffer.clear() # ⚠️ clear the buffer or it will keep growing forever!\n",
" node_data_buffer.clear() # didn't log node_data_buffer, but still need to clear it"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions llama_index/callbacks/open_inference_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,12 @@ class OpenInferenceCallbackHandler(BaseCallbackHandler):

def __init__(
self,
callback: Optional[Callable[[List[QueryData]], None]] = None,
callback: Optional[Callable[[List[QueryData], List[NodeData]], None]] = None,
) -> None:
"""Initializes the OpenInferenceCallbackHandler.
Args:
callback (Optional[Callable[[List[QueryData]], None]], optional): A
callback (Optional[Callable[[List[QueryData], List[NodeData]], None]], optional): A
callback function that will be called when a query trace is
completed, often used for logging or persisting query data.
"""
Expand All @@ -180,7 +180,7 @@ def end_trace(
self._node_data_buffer.extend(self._trace_data.node_datas)
self._trace_data = TraceData()
if self._callback is not None:
self._callback(self._query_data_buffer)
self._callback(self._query_data_buffer, self._node_data_buffer)

def on_event_start(
self,
Expand Down

0 comments on commit 36384dc

Please sign in to comment.