Handle parallel function calls for OpenAI LLMs #522

Merged: 4 commits, Sep 30, 2024
11 changes: 8 additions & 3 deletions examples/foundational/14-function-calling.py
@@ -34,7 +34,12 @@


 async def start_fetch_weather(function_name, llm, context):
-    await llm.push_frame(TextFrame("Let me check on that."))
+    # note: we can't push a frame to the LLM here. the bot
+    # can interrupt itself and/or cause audio overlapping glitches.
+    # possible question for Aleix and Chad about what the right way
+    # to trigger speech is, now, with the new queues/async/sync refactors.
+    # await llm.push_frame(TextFrame("Let me check on that."))
     logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")
Contributor:

The rule is... if a processor can push frames from a different task then it should be async. But function calls happen in the same task, no? If so, this should be OK.

Contributor:

I just realized that, by default, all processors should be async and not the other way around. It makes more sense and it's safer this way. That was my initial idea, but then I changed my mind. But I think that was the right idea. I'll make the change in the morning, and I believe that line should be safe.

Contributor (Author):

I will rebase and test again after this change.
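As a generic illustration of the overlap concern in the comment thread above (this is not pipecat code; every name here is made up): two uncoordinated tasks writing to one output queue interleave their chunks, which is roughly what overlapping bot audio looks like at the frame level.

import asyncio


async def producer(name: str, queue: asyncio.Queue):
    # each producer pushes its own "speech" chunks with no coordination
    for i in range(3):
        await queue.put(f"{name}-chunk-{i}")
        await asyncio.sleep(0)  # yield so the other task can interleave


async def main():
    queue: asyncio.Queue = asyncio.Queue()
    # two writers to one output, as when a function-call handler pushes a
    # TextFrame while the LLM response is still streaming
    await asyncio.gather(producer("llm", queue), producer("handler", queue))
    while not queue.empty():
        print(await queue.get())  # chunks from both producers arrive interleaved


asyncio.run(main())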



 async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
@@ -106,11 +111,11 @@ async def main():

     pipeline = Pipeline(
         [
-            fl_in,
+            # fl_in,
             transport.input(),
             context_aggregator.user(),
             llm,
-            fl_out,
+            # fl_out,
Contributor:

should we just remove the frame loggers?

Contributor (@chadbailey59, Sep 30, 2024):

Probably. I want to leave them in an example or two so people know they exist, though.

Contributor (Author):

I was thinking the same thing as @chadbailey59. No reason not to leave them in a few examples as, um, examples.

Contributor:

If we leave them we should uncomment them I think.

             tts,
             transport.output(),
             context_aggregator.assistant(),
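For context on the thread above: fl_in and fl_out are FrameLogger processors. A minimal sketch of the pattern being discussed, assuming the import path in this version of pipecat (pipecat.processors.logger):

from pipecat.processors.logger import FrameLogger

# A frame logger prints every frame that flows through its spot in the
# pipeline; one before and one after the LLM shows what the LLM consumes
# and what it emits.
fl_in = FrameLogger("LLM input")
fl_out = FrameLogger("LLM output")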
1 change: 1 addition & 0 deletions src/pipecat/frames/frames.py
@@ -585,6 +585,7 @@ class FunctionCallResultFrame(DataFrame):
     tool_call_id: str
     arguments: str
     result: Any
+    run_llm: bool = True


 @dataclass
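A note for reviewers on the new field: run_llm tells the assistant context aggregator whether to kick off a fresh completion once this result lands in the context. With parallel calls, every result except the last is pushed with run_llm=False, so the LLM is re-prompted exactly once; the default of True keeps single-call behavior unchanged. A hypothetical sketch of a two-call batch (the frame's function_name field sits just above this excerpt; all values are made up):

first = FunctionCallResultFrame(
    function_name="get_weather",
    tool_call_id="call_1",
    arguments='{"location": "San Francisco"}',
    result={"conditions": "sunny"},
    run_llm=False,  # more parallel results still pending
)
last = FunctionCallResultFrame(
    function_name="get_time",
    tool_call_id="call_2",
    arguments='{"timezone": "America/Los_Angeles"}',
    result={"time": "10:00"},
    run_llm=True,  # last result of the batch: re-prompt the LLM once
)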
2 changes: 2 additions & 0 deletions src/pipecat/processors/aggregators/openai_llm_context.py
@@ -133,6 +133,7 @@ async def call_function(
         tool_call_id: str,
         arguments: str,
         llm: FrameProcessor,
+        run_llm: bool = True,
     ) -> None:
         # Push a SystemFrame downstream. This frame will let our assistant context aggregator
         # know that we are in the middle of a function call. Some contexts/aggregators may
@@ -153,6 +154,7 @@ async def function_call_result_callback(result):
                     tool_call_id=tool_call_id,
                     arguments=arguments,
                     result=result,
+                    run_llm=run_llm,
                 )
             )

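For context, the result_callback closure above is what user-registered handlers invoke to report their results. A minimal sketch of the registration pattern, following the handler signature from 14-function-calling.py (the handler body is a stub):

async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
    # a real handler would call a weather API here
    await result_callback({"conditions": "nice", "temperature": "75"})

# ties the tool name from the LLM's tool schema to the handler;
# the service finds it later via has_function()
llm.register_function("get_current_weather", fetch_weather_from_api)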
15 changes: 13 additions & 2 deletions src/pipecat/services/ai_services.py
@@ -110,7 +110,13 @@ def has_function(self, function_name: str):
         return function_name in self._callbacks.keys()

     async def call_function(
-        self, *, context: OpenAILLMContext, tool_call_id: str, function_name: str, arguments: str
+        self,
+        *,
+        context: OpenAILLMContext,
+        tool_call_id: str,
+        function_name: str,
+        arguments: str,
+        run_llm: bool,
     ) -> None:
         f = None
         if function_name in self._callbacks.keys():
@@ -120,7 +126,12 @@ async def call_function(
         else:
             return None
         await context.call_function(
-            f, function_name=function_name, tool_call_id=tool_call_id, arguments=arguments, llm=self
+            f,
+            function_name=function_name,
+            tool_call_id=tool_call_id,
+            arguments=arguments,
+            llm=self,
+            run_llm=run_llm,
         )

     # QUESTION FOR CB: maybe this isn't needed anymore?
68 changes: 42 additions & 26 deletions src/pipecat/services/openai.py
@@ -205,6 +205,10 @@ async def _stream_chat_completions(
         return chunks

     async def _process_context(self, context: OpenAILLMContext):
+        functions_list = []
+        arguments_list = []
+        tool_id_list = []
+        func_idx = 0
         function_name = ""
         arguments = ""
         tool_call_id = ""
@@ -242,6 +246,14 @@ async def _process_context(self, context: OpenAILLMContext):
                 # yield a frame containing the function name and the arguments.

                 tool_call = chunk.choices[0].delta.tool_calls[0]
+                if tool_call.index != func_idx:
+                    functions_list.append(function_name)
+                    arguments_list.append(arguments)
+                    tool_id_list.append(tool_call_id)
+                    function_name = ""
+                    arguments = ""
+                    tool_call_id = ""
+                    func_idx += 1
                 if tool_call.function and tool_call.function.name:
                     function_name += tool_call.function.name
                     tool_call_id = tool_call.id
@@ -257,21 +269,29 @@
         # the context, and re-prompt to get a chat answer. If we don't have a registered
         # handler, raise an exception.
         if function_name and arguments:
-            if self.has_function(function_name):
-                await self._handle_function_call(context, tool_call_id, function_name, arguments)
-            else:
-                raise OpenAIUnhandledFunctionException(
-                    f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
-                )
-
-    async def _handle_function_call(self, context, tool_call_id, function_name, arguments):
-        arguments = json.loads(arguments)
-        await self.call_function(
-            context=context,
-            tool_call_id=tool_call_id,
-            function_name=function_name,
-            arguments=arguments,
-        )
+            # append the final call: the stream ended without another index bump,
+            # so the last function name and arguments were never flushed above
+            functions_list.append(function_name)
+            arguments_list.append(arguments)
+            tool_id_list.append(tool_call_id)
+
+            total_items = len(functions_list)
+            for index, (function_name, arguments, tool_id) in enumerate(
+                zip(functions_list, arguments_list, tool_id_list), start=1
+            ):
+                if self.has_function(function_name):
+                    run_llm = index == total_items
+                    arguments = json.loads(arguments)
+                    await self.call_function(
+                        context=context,
+                        function_name=function_name,
+                        arguments=arguments,
+                        tool_call_id=tool_id,
+                        run_llm=run_llm,
+                    )
+                else:
+                    raise OpenAIUnhandledFunctionException(
+                        f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
+                    )

     async def _update_settings(self, frame: LLMUpdateSettingsFrame):
         if frame.model is not None:
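To make the index bookkeeping above concrete: when the model issues parallel calls, each streamed delta is tagged with a tool_calls index, and a new index means the previous call is complete. A self-contained sketch of the same grouping logic, using simplified tuples in place of the SDK's delta objects (the data is made up):

# simplified deltas: (index, name_fragment, args_fragment, tool_call_id)
deltas = [
    (0, "get_weather", "", "call_1"),
    (0, None, '{"location": ', None),
    (0, None, '"San Francisco"}', None),
    (1, "get_time", "", "call_2"),
    (1, None, '{"timezone": "America/Los_Angeles"}', None),
]

functions_list, arguments_list, tool_id_list = [], [], []
function_name, arguments, tool_call_id, func_idx = "", "", "", 0

for index, name, args, call_id in deltas:
    if index != func_idx:
        # the index advanced, so the previous call is complete: flush it
        functions_list.append(function_name)
        arguments_list.append(arguments)
        tool_id_list.append(tool_call_id)
        function_name, arguments, tool_call_id = "", "", ""
        func_idx += 1
    if name:
        function_name += name
        tool_call_id = call_id
    if args:
        arguments += args

# the stream ends without a final index bump, so flush the last call too
functions_list.append(function_name)
arguments_list.append(arguments)
tool_id_list.append(tool_call_id)

print(list(zip(tool_id_list, functions_list, arguments_list)))
# [('call_1', 'get_weather', '{"location": "San Francisco"}'),
#  ('call_2', 'get_time', '{"timezone": "America/Los_Angeles"}')]

The run_llm = index == total_items line in the diff then guarantees that only the last flushed call triggers a new completion.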
@@ -465,31 +485,27 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
     def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
         super().__init__(context=user_context_aggregator._context, **kwargs)
         self._user_context_aggregator = user_context_aggregator
-        self._function_call_in_progress = None
+        self._function_calls_in_progress = {}
         self._function_call_result = None

     async def process_frame(self, frame, direction):
         await super().process_frame(frame, direction)
         # See note above about not calling push_frame() here.
         if isinstance(frame, StartInterruptionFrame):
-            self._function_call_in_progress = None
+            self._function_calls_in_progress.clear()
             self._function_call_finished = None
         elif isinstance(frame, FunctionCallInProgressFrame):
-            self._function_call_in_progress = frame
+            self._function_calls_in_progress[frame.tool_call_id] = frame
         elif isinstance(frame, FunctionCallResultFrame):
-            if (
-                self._function_call_in_progress
-                and self._function_call_in_progress.tool_call_id == frame.tool_call_id
-            ):
-                self._function_call_in_progress = None
+            if frame.tool_call_id in self._function_calls_in_progress:
+                del self._function_calls_in_progress[frame.tool_call_id]
                 self._function_call_result = frame
                 # TODO-CB: Kwin wants us to refactor this out of here but I REFUSE
                 await self._push_aggregation()
             else:
                 logger.warning(
-                    "FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
+                    "FunctionCallResultFrame tool_call_id does not match any function call in progress"
                 )
-                self._function_call_in_progress = None
                 self._function_call_result = None
     async def _push_aggregation(self):
@@ -528,7 +544,7 @@ async def _push_aggregation(self):
                         "tool_call_id": frame.tool_call_id,
                     }
                 )
-                run_llm = True
+                run_llm = frame.run_llm
             else:
                 self._context.add_message({"role": "assistant", "content": aggregation})
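Why the aggregator now keys in-progress calls by tool_call_id: the Chat Completions format expects a single assistant message listing every parallel tool_calls entry, followed by one role "tool" message per tool_call_id, before the next completion is requested. An illustrative sketch of what the context should hold after a two-call batch resolves (plain dicts; the values are made up):

messages = [
    {"role": "user", "content": "What's the weather in San Francisco, and what time is it?"},
    {
        "role": "assistant",
        "tool_calls": [
            {"id": "call_1", "type": "function",
             "function": {"name": "get_weather", "arguments": '{"location": "San Francisco"}'}},
            {"id": "call_2", "type": "function",
             "function": {"name": "get_time", "arguments": '{"timezone": "America/Los_Angeles"}'}},
        ],
    },
    # one tool message per call id; the aggregator appends these as each
    # FunctionCallResultFrame arrives, and re-runs the LLM only when the
    # frame carrying run_llm=True (the last of the batch) is processed
    {"role": "tool", "tool_call_id": "call_1", "content": '{"conditions": "sunny"}'},
    {"role": "tool", "tool_call_id": "call_2", "content": '{"time": "10:00"}'},
]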
