Skip to content

Commit

Permalink
[PY] feat: Chat Completion - Tools (#1853)
Browse files Browse the repository at this point in the history
## Linked issues

closes: #1699

## Details

- providing feature support w/ function calling in [chat
completions](https://platform.openai.com/docs/guides/function-calling)

## Attestation Checklist

- [x] My code follows the style guidelines of this project

- I have checked for/fixed spelling, linting, and other errors
- I have commented my code for clarity
- I have made corresponding changes to the documentation (updating the
doc strings in the code is sufficient)
- My changes generate no new warnings
- I have added tests that validates my changes, and provides sufficient
test coverage. I have tested with:
  - Local testing
  - E2E testing in Teams
- New and existing unit tests pass locally with my changes

---------

Co-authored-by: Alex Acebo <aacebowork@gmail.com>
  • Loading branch information
lilyydu and aacebo authored Aug 13, 2024
1 parent 6174f28 commit 5fda47e
Show file tree
Hide file tree
Showing 21 changed files with 832 additions and 51 deletions.
5 changes: 5 additions & 0 deletions python/packages/ai/teams/ai/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ async def run(
)
loop = len(output) > 0
state.temp.action_outputs[command.action] = output

# Set output for action call
if command.action_id:
loop = True
state.temp.action_outputs[command.action_id] = output or ""
else:
output = await self._actions[ActionTypes.UNKNOWN_ACTION].invoke(
context, state, plan, command.action
Expand Down
2 changes: 2 additions & 0 deletions python/packages/ai/teams/ai/augmentations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from .default_augmentation import DefaultAugmentation
from .monologue_augmentation import MonologueAugmentation
from .sequence_augmentation import SequenceAugmentation
from .tools_augmentation import ToolsAugmentation

__all__ = [
"Augmentation",
"DefaultAugmentation",
"MonologueAugmentation",
"SequenceAugmentation",
"ToolsAugmentation",
]
6 changes: 3 additions & 3 deletions python/packages/ai/teams/ai/augmentations/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Union
from typing import Generic, Optional, TypeVar

from botbuilder.core import TurnContext

Expand All @@ -27,12 +27,12 @@ class Augmentation(PromptResponseValidator, ABC, Generic[ValueT]):
"""

@abstractmethod
def create_prompt_section(self) -> Union[PromptSection, None]:
def create_prompt_section(self) -> Optional[PromptSection]:
"""
Creates an optional prompt section for the augmentation.
Returns:
Union[PromptSection, None]: The prompt section.
Optional[PromptSection]: The prompt section.
"""

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import annotations

from typing import TypeVar, Union
from typing import Optional, TypeVar

from botbuilder.core import TurnContext

Expand All @@ -28,7 +28,7 @@ class DefaultAugmentation(Augmentation[str]):
returns a `Plan` with a single `SAY` command containing the models response.
"""

def create_prompt_section(self) -> Union[PromptSection, None]:
def create_prompt_section(self) -> Optional[PromptSection]:
"""
Creates an optional prompt section for the augmentation.
"""
Expand Down
94 changes: 94 additions & 0 deletions python/packages/ai/teams/ai/augmentations/tools_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""

from __future__ import annotations

import json
from typing import List, Optional

from botbuilder.core import TurnContext

from ...state import MemoryBase
from ..models.prompt_response import PromptResponse
from ..planners.plan import (
Plan,
PredictedCommand,
PredictedDoCommand,
PredictedSayCommand,
)
from ..prompts.sections.prompt_section import PromptSection
from ..tokenizers.tokenizer import Tokenizer
from ..validators.validation import Validation
from .augmentation import Augmentation


class ToolsAugmentation(Augmentation[str]):
"""
A server-side 'tools' augmentation.
"""

def create_prompt_section(self) -> Optional[PromptSection]:
"""
Creates an optional prompt section for the augmentation.
"""
return None

async def validate_response(
self,
context: TurnContext,
memory: MemoryBase,
tokenizer: Tokenizer,
response: PromptResponse[str],
remaining_attempts: int,
) -> Validation:
"""
Validates a response to a prompt.
Args:
context (TurnContext): Context for the current turn of conversation.
memory (MemoryBase): Interface for accessing state variables.
tokenizer (Tokenizer): Tokenizer to use for encoding/decoding text.
response (PromptResponse[str]): Response to validate.
remaining_attempts (int): Nubmer of remaining attempts to validate the response.
Returns:
Validation: A 'Validation' object.
"""
return Validation(valid=True)

async def create_plan_from_response(
self,
turn_context: TurnContext,
memory: MemoryBase,
response: PromptResponse[str],
) -> Plan:
"""
Creates a plan given validated response value.
Args:
turn_context (TurnContext): Context for the current turn of conversation.
memory (MemoryBase): Interface for accessing state variables.
response (PromptResponse[str]):
The validated and transformed response for the prompt.
Returns:
Plan: The created plan.
"""

commands: List[PredictedCommand] = []

if response.message and response.message.action_calls:
tool_calls = response.message.action_calls

for tool in tool_calls:
command = PredictedDoCommand(
action=tool.function.name,
parameters=json.loads(tool.function.arguments),
action_id=tool.id,
)
commands.append(command)
return Plan(commands=commands)

return Plan(commands=[PredictedSayCommand(response=response.message)])
23 changes: 9 additions & 14 deletions python/packages/ai/teams/ai/clients/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from __future__ import annotations

import json
from dataclasses import dataclass, field
from logging import Logger
from typing import Any, List, Optional
Expand Down Expand Up @@ -91,27 +90,19 @@ def __init__(self, options: LLMClientOptions) -> None:

self._options = options

def add_function_result_to_history(self, memory: MemoryBase, name: str, results: Any) -> None:
def add_action_output_to_history(self, memory: MemoryBase, id: str, results: str) -> None:
"""
Adds a result from a `function_call` to the history.
Adds the result from an `action_call` to the history.
Args:
memory (MemoryBase): An interface for accessing state values.
name (str): Name of the function that was called.
results (Any): Results returned by the function.
id (str): Id of the action that was called.
results (str): Results returned by the action call.
"""

content = ""

if isinstance(results, object):
content = json.dumps(results)
else:
content = str(results)

self._add_message_to_history(
memory=memory,
variable=self._options.history_variable,
message=Message(role="function", name=name, content=content),
message=Message(role="tool", action_call_id=id, content=results),
)

async def complete_prompt(
Expand Down Expand Up @@ -224,4 +215,8 @@ def _add_message_to_history(
if len(history) > self._options.max_history_messages:
del history[0 : len(history) - self._options.max_history_messages]

# Remove completed partial action outputs
while history and history[0].role == "tool":
del history[0]

memory.set(variable, history)
90 changes: 86 additions & 4 deletions python/packages/ai/teams/ai/models/openai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@

import openai
from botbuilder.core import TurnContext
from openai.types import chat
from openai import NOT_GIVEN
from openai.types import chat, shared_params
from openai.types.chat.chat_completion_message_tool_call_param import Function

from ...state import MemoryBase
from ..prompts.message import Message, MessageContext
from ..prompts.message import ActionCall, ActionFunction, Message, MessageContext
from ..prompts.prompt_functions import PromptFunctions
from ..prompts.prompt_template import PromptTemplate
from ..tokenizers import Tokenizer
Expand Down Expand Up @@ -124,11 +126,43 @@ async def complete_prompt(
template: PromptTemplate,
) -> PromptResponse[str]:
max_input_tokens = template.config.completion.max_input_tokens

# Setup tools if enabled
is_tools_aug = (
template.config.augmentation
and template.config.augmentation.augmentation_type == "tools"
)
tool_choice = (
template.config.completion.tool_choice
if template.config.completion.tool_choice is not None
else "auto"
)
parallel_tool_calls = (
template.config.completion.parallel_tool_calls
if template.config.completion.parallel_tool_calls is not None
else True
)
tools: List[chat.ChatCompletionToolParam] = []

# If tools is enabled, reformat actions to schema
if is_tools_aug and template.actions:
for action in template.actions:
curr_tool = chat.ChatCompletionToolParam(
type="function",
function=shared_params.FunctionDefinition(
name=action.name,
description=action.description or "",
parameters=action.parameters or {},
),
)
tools.append(curr_tool)

model = (
template.config.completion.model
if template.config.completion.model is not None
else self._options.default_model
)

res = await template.prompt.render_as_messages(
context=context,
memory=memory,
Expand Down Expand Up @@ -156,22 +190,47 @@ async def complete_prompt(
chat.ChatCompletionUserMessageParam,
chat.ChatCompletionAssistantMessageParam,
chat.ChatCompletionSystemMessageParam,
chat.ChatCompletionToolMessageParam,
] = chat.ChatCompletionUserMessageParam(
role="user",
content=msg.content if msg.content is not None else "",
)

if msg.name:
param["name"] = msg.name
setattr(param, "name", msg.name)

if msg.role == "assistant":
param = chat.ChatCompletionAssistantMessageParam(
role="assistant",
content=msg.content if msg.content is not None else "",
)

tool_call_params: List[chat.ChatCompletionMessageToolCallParam] = []

if msg.action_calls and len(msg.action_calls) > 0:
for tool_call in msg.action_calls:
tool_call_params.append(
chat.ChatCompletionMessageToolCallParam(
id=tool_call.id,
function=Function(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
type=tool_call.type,
)
)
param["content"] = None
param["tool_calls"] = tool_call_params

if msg.name:
param["name"] = msg.name

elif msg.role == "tool":
param = chat.ChatCompletionToolMessageParam(
role="tool",
tool_call_id=msg.action_call_id if msg.action_call_id else "",
content=msg.content if msg.content else "",
)
elif msg.role == "system":
param = chat.ChatCompletionSystemMessageParam(
role="system",
Expand All @@ -187,6 +246,7 @@ async def complete_prompt(
extra_body = {}
if template.config.completion.data_sources is not None:
extra_body["data_sources"] = template.config.completion.data_sources

completion = await self._client.chat.completions.create(
messages=messages,
model=model,
Expand All @@ -195,24 +255,46 @@ async def complete_prompt(
top_p=template.config.completion.top_p,
temperature=template.config.completion.temperature,
max_tokens=template.config.completion.max_tokens,
tools=tools if len(tools) > 0 else NOT_GIVEN,
tool_choice=tool_choice if len(tools) > 0 else NOT_GIVEN,
parallel_tool_calls=parallel_tool_calls if len(tools) > 0 else NOT_GIVEN,
extra_body=extra_body,
)

if self._options.logger is not None:
self._options.logger.debug("COMPLETION:\n%s", completion.model_dump_json())

# Handle tools flow
action_calls = []
response_message = completion.choices[0].message
tool_calls = response_message.tool_calls

if is_tools_aug and tool_calls:
for curr_tool_call in tool_calls:
action_calls.append(
ActionCall(
id=curr_tool_call.id,
type=curr_tool_call.type,
function=ActionFunction(
name=curr_tool_call.function.name,
arguments=curr_tool_call.function.arguments,
),
)
)

input: Optional[Message] = None
last_message = len(res.output) - 1

# Skips the first message which is the prompt
if last_message > 0 and res.output[last_message].role == "user":
if last_message > 0 and res.output[last_message].role != "assistant":
input = res.output[last_message]

return PromptResponse[str](
input=input,
message=Message(
role=completion.choices[0].message.role,
content=completion.choices[0].message.content,
action_calls=(action_calls if is_tools_aug and len(action_calls) > 0 else None),
context=(
MessageContext.from_dict(completion.choices[0].message.context)
if hasattr(completion.choices[0].message, "context")
Expand Down
Loading

0 comments on commit 5fda47e

Please sign in to comment.