From def04ac0ce0fe6f90c66e523066c3b4517dbc8d3 Mon Sep 17 00:00:00 2001 From: JeevanReddy Date: Wed, 7 Aug 2024 13:07:18 +0530 Subject: [PATCH 1/4] openai can give multiple tool calls, current implementation assumes only one function call at a time. Fixed this to handle multiple function calls. --- src/pipecat/services/openai.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index c17916f2d..b17dd7397 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -205,6 +205,10 @@ async def _stream_chat_completions( return chunks async def _process_context(self, context: OpenAILLMContext): + functions_list = [] + arguments_list = [] + tool_id_list = [] + func_idx = 0 function_name = "" arguments = "" tool_call_id = "" @@ -242,6 +246,14 @@ async def _process_context(self, context: OpenAILLMContext): # yield a frame containing the function name and the arguments. tool_call = chunk.choices[0].delta.tool_calls[0] + if tool_call.index != func_idx: + functions_list.append(function_name) + arguments_list.append(arguments) + tool_id_list.append(tool_call_id) + function_name = "" + arguments = "" + tool_call_id = "" + func_idx += 1 if tool_call.function and tool_call.function.name: function_name += tool_call.function.name tool_call_id = tool_call.id @@ -257,12 +269,21 @@ async def _process_context(self, context: OpenAILLMContext): # the context, and re-prompt to get a chat answer. If we don't have a registered # handler, raise an exception. if function_name and arguments: - if self.has_function(function_name): - await self._handle_function_call(context, tool_call_id, function_name, arguments) - else: - raise OpenAIUnhandledFunctionException( - f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function." - ) + # added to the list as last function name and arguments not added to the list + functions_list.append(function_name) + arguments_list.append(arguments) + tool_id_list.append(tool_call_id) + for function_name, arguments, tool_id in zip( + functions_list, arguments_list, tool_id_list + ): + if self.has_function(function_name): + await self._handle_function_call(context, tool_id, function_name, arguments) + else: + raise OpenAIUnhandledFunctionException( + f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function." + ) + # re-prompt to get a human answer after all the functions are called + await self._process_context(context) async def _handle_function_call(self, context, tool_call_id, function_name, arguments): arguments = json.loads(arguments) From a5c73ec829685f302b3fdb8450de2ac75297b72e Mon Sep 17 00:00:00 2001 From: Kwindla Hultman Kramer Date: Sun, 29 Sep 2024 21:03:59 -0700 Subject: [PATCH 2/4] handle openai multiple function calls --- examples/foundational/14-function-calling.py | 11 ++-- src/pipecat/frames/frames.py | 1 + .../aggregators/openai_llm_context.py | 2 + src/pipecat/services/ai_services.py | 15 +++++- src/pipecat/services/openai.py | 54 +++++++++---------- 5 files changed, 50 insertions(+), 33 deletions(-) diff --git a/examples/foundational/14-function-calling.py b/examples/foundational/14-function-calling.py index b5aba449c..9141029ca 100644 --- a/examples/foundational/14-function-calling.py +++ b/examples/foundational/14-function-calling.py @@ -34,7 +34,12 @@ async def start_fetch_weather(function_name, llm, context): - await llm.push_frame(TextFrame("Let me check on that.")) + # note: we can't push a frame to the LLM here. the bot + # can interrupt itself and/or cause audio overlapping glitches. + # possible question for Aleix and Chad about what the right way + # to trigger speech is, now, with the new queues/async/sync refactors. + # await llm.push_frame(TextFrame("Let me check on that.")) + logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}") async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback): @@ -106,11 +111,11 @@ async def main(): pipeline = Pipeline( [ - fl_in, + # fl_in, transport.input(), context_aggregator.user(), llm, - fl_out, + # fl_out, tts, transport.output(), context_aggregator.assistant(), diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 8059b904b..f7faa8ef0 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -585,6 +585,7 @@ class FunctionCallResultFrame(DataFrame): tool_call_id: str arguments: str result: Any + run_llm: bool = True @dataclass diff --git a/src/pipecat/processors/aggregators/openai_llm_context.py b/src/pipecat/processors/aggregators/openai_llm_context.py index 83ec3e57f..4bf3f042c 100644 --- a/src/pipecat/processors/aggregators/openai_llm_context.py +++ b/src/pipecat/processors/aggregators/openai_llm_context.py @@ -133,6 +133,7 @@ async def call_function( tool_call_id: str, arguments: str, llm: FrameProcessor, + run_llm: bool = True, ) -> None: # Push a SystemFrame downstream. This frame will let our assistant context aggregator # know that we are in the middle of a function call. Some contexts/aggregators may @@ -153,6 +154,7 @@ async def function_call_result_callback(result): tool_call_id=tool_call_id, arguments=arguments, result=result, + run_llm=run_llm, ) ) diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index 5eadb475b..a46ad3fab 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -110,7 +110,13 @@ def has_function(self, function_name: str): return function_name in self._callbacks.keys() async def call_function( - self, *, context: OpenAILLMContext, tool_call_id: str, function_name: str, arguments: str + self, + *, + context: OpenAILLMContext, + tool_call_id: str, + function_name: str, + arguments: str, + run_llm: bool, ) -> None: f = None if function_name in self._callbacks.keys(): @@ -120,7 +126,12 @@ async def call_function( else: return None await context.call_function( - f, function_name=function_name, tool_call_id=tool_call_id, arguments=arguments, llm=self + f, + function_name=function_name, + tool_call_id=tool_call_id, + arguments=arguments, + llm=self, + run_llm=run_llm, ) # QUESTION FOR CB: maybe this isn't needed anymore? diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index b17dd7397..73dae4644 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -273,26 +273,21 @@ async def _process_context(self, context: OpenAILLMContext): functions_list.append(function_name) arguments_list.append(arguments) tool_id_list.append(tool_call_id) - for function_name, arguments, tool_id in zip( - functions_list, arguments_list, tool_id_list + + total_items = len(functions_list) + for index, (function_name, arguments, tool_id) in enumerate( + zip(functions_list, arguments_list, tool_id_list), start=1 ): if self.has_function(function_name): - await self._handle_function_call(context, tool_id, function_name, arguments) - else: - raise OpenAIUnhandledFunctionException( - f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function." + run_llm = index == total_items + arguments = json.loads(arguments) + await self.call_function( + context=context, + function_name=function_name, + arguments=arguments, + tool_call_id=tool_id, + run_llm=run_llm, ) - # re-prompt to get a human answer after all the functions are called - await self._process_context(context) - - async def _handle_function_call(self, context, tool_call_id, function_name, arguments): - arguments = json.loads(arguments) - await self.call_function( - context=context, - tool_call_id=tool_call_id, - function_name=function_name, - arguments=arguments, - ) async def _update_settings(self, frame: LLMUpdateSettingsFrame): if frame.model is not None: @@ -486,31 +481,34 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs): super().__init__(context=user_context_aggregator._context, **kwargs) self._user_context_aggregator = user_context_aggregator - self._function_call_in_progress = None + self._function_calls_in_progress = {} self._function_call_result = None async def process_frame(self, frame, direction): await super().process_frame(frame, direction) # See note above about not calling push_frame() here. if isinstance(frame, StartInterruptionFrame): - self._function_call_in_progress = None + self._function_calls_in_progress.clear() self._function_call_finished = None + logger.debug("clearing function calls in progress") elif isinstance(frame, FunctionCallInProgressFrame): - self._function_call_in_progress = frame + self._function_calls_in_progress[frame.tool_call_id] = frame + logger.debug( + f"FunctionCallInProgressFrame: {frame.tool_call_id} {self._function_calls_in_progress}" + ) elif isinstance(frame, FunctionCallResultFrame): - if ( - self._function_call_in_progress - and self._function_call_in_progress.tool_call_id == frame.tool_call_id - ): - self._function_call_in_progress = None + logger.debug( + f"FunctionCallResultFrame: {frame.tool_call_id} {self._function_calls_in_progress}" + ) + if frame.tool_call_id in self._function_calls_in_progress: + del self._function_calls_in_progress[frame.tool_call_id] self._function_call_result = frame # TODO-CB: Kwin wants us to refactor this out of here but I REFUSE await self._push_aggregation() else: logger.warning( - "FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id" + "FunctionCallResultFrame tool_call_id does not match any function call in progress" ) - self._function_call_in_progress = None self._function_call_result = None async def _push_aggregation(self): @@ -549,7 +547,7 @@ async def _push_aggregation(self): "tool_call_id": frame.tool_call_id, } ) - run_llm = True + run_llm = frame.run_llm else: self._context.add_message({"role": "assistant", "content": aggregation}) From 6ad3437fd2b3511b7aef91d4f9785b30ff0dc1ec Mon Sep 17 00:00:00 2001 From: Kwindla Hultman Kramer Date: Sun, 29 Sep 2024 21:10:21 -0700 Subject: [PATCH 3/4] throw error if the llm tries to call a function that's not registered --- src/pipecat/services/openai.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index 73dae4644..8a032ea40 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -288,6 +288,10 @@ async def _process_context(self, context: OpenAILLMContext): tool_call_id=tool_id, run_llm=run_llm, ) + else: + raise OpenAIUnhandledFunctionException( + f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function." + ) async def _update_settings(self, frame: LLMUpdateSettingsFrame): if frame.model is not None: From 0499fe41e455c700c2de11a1eebb84cc3f71a573 Mon Sep 17 00:00:00 2001 From: Kwindla Hultman Kramer Date: Sun, 29 Sep 2024 21:12:09 -0700 Subject: [PATCH 4/4] get rid of some debug log lines used during development --- src/pipecat/services/openai.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index 8a032ea40..49fd04371 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -494,16 +494,9 @@ async def process_frame(self, frame, direction): if isinstance(frame, StartInterruptionFrame): self._function_calls_in_progress.clear() self._function_call_finished = None - logger.debug("clearing function calls in progress") elif isinstance(frame, FunctionCallInProgressFrame): self._function_calls_in_progress[frame.tool_call_id] = frame - logger.debug( - f"FunctionCallInProgressFrame: {frame.tool_call_id} {self._function_calls_in_progress}" - ) elif isinstance(frame, FunctionCallResultFrame): - logger.debug( - f"FunctionCallResultFrame: {frame.tool_call_id} {self._function_calls_in_progress}" - ) if frame.tool_call_id in self._function_calls_in_progress: del self._function_calls_in_progress[frame.tool_call_id] self._function_call_result = frame